diff --git a/fluss-lake/fluss-lake-lance/src/main/java/org/apache/fluss/lake/lance/LanceLakeCatalog.java b/fluss-lake/fluss-lake-lance/src/main/java/org/apache/fluss/lake/lance/LanceLakeCatalog.java index 600dcbd0d9..ad8f7687a4 100644 --- a/fluss-lake/fluss-lake-lance/src/main/java/org/apache/fluss/lake/lance/LanceLakeCatalog.java +++ b/fluss-lake/fluss-lake-lance/src/main/java/org/apache/fluss/lake/lance/LanceLakeCatalog.java @@ -61,7 +61,9 @@ public void createTable(TablePath tablePath, TableDescriptor tableDescriptor, Co List fields = new ArrayList<>(); // set schema fields.addAll( - LanceArrowUtils.toArrowSchema(tableDescriptor.getSchema().getRowType()) + LanceArrowUtils.toArrowSchema( + tableDescriptor.getSchema().getRowType(), + tableDescriptor.getCustomProperties()) .getFields()); try { LanceDatasetAdapter.createDataset(config.getDatasetUri(), new Schema(fields), params); diff --git a/fluss-lake/fluss-lake-lance/src/main/java/org/apache/fluss/lake/lance/utils/ArrowDataConverter.java b/fluss-lake/fluss-lake-lance/src/main/java/org/apache/fluss/lake/lance/utils/ArrowDataConverter.java index ed822e8b2e..8091379a6f 100644 --- a/fluss-lake/fluss-lake-lance/src/main/java/org/apache/fluss/lake/lance/utils/ArrowDataConverter.java +++ b/fluss-lake/fluss-lake-lance/src/main/java/org/apache/fluss/lake/lance/utils/ArrowDataConverter.java @@ -21,6 +21,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.FixedSizeListVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.types.pojo.Schema; @@ -54,16 +55,21 @@ public static VectorSchemaRoot convertToNonShaded( VectorSchemaRoot.create(nonShadedSchema, nonShadedAllocator); nonShadedRoot.allocateNew(); - List shadedVectors = - shadedRoot.getFieldVectors(); - List nonShadedVectors = nonShadedRoot.getFieldVectors(); + try { + List shadedVectors = + shadedRoot.getFieldVectors(); + List nonShadedVectors = nonShadedRoot.getFieldVectors(); - for (int i = 0; i < shadedVectors.size(); i++) { - copyVectorData(shadedVectors.get(i), nonShadedVectors.get(i)); - } + for (int i = 0; i < shadedVectors.size(); i++) { + copyVectorData(shadedVectors.get(i), nonShadedVectors.get(i)); + } - nonShadedRoot.setRowCount(shadedRoot.getRowCount()); - return nonShadedRoot; + nonShadedRoot.setRowCount(shadedRoot.getRowCount()); + return nonShadedRoot; + } catch (Exception e) { + nonShadedRoot.close(); + throw e; + } } private static void copyVectorData( @@ -71,14 +77,21 @@ private static void copyVectorData( FieldVector nonShadedVector) { if (shadedVector - instanceof - org.apache.fluss.shaded.arrow.org.apache.arrow.vector.complex.ListVector - && nonShadedVector instanceof ListVector) { - copyListVectorData( - (org.apache.fluss.shaded.arrow.org.apache.arrow.vector.complex.ListVector) - shadedVector, - (ListVector) nonShadedVector); - return; + instanceof + org.apache.fluss.shaded.arrow.org.apache.arrow.vector.complex.ListVector) { + if (nonShadedVector instanceof FixedSizeListVector) { + copyListToFixedSizeListVectorData( + (org.apache.fluss.shaded.arrow.org.apache.arrow.vector.complex.ListVector) + shadedVector, + (FixedSizeListVector) nonShadedVector); + return; + } else if (nonShadedVector instanceof ListVector) { + copyListVectorData( + (org.apache.fluss.shaded.arrow.org.apache.arrow.vector.complex.ListVector) + shadedVector, + (ListVector) nonShadedVector); + return; + } } List shadedBuffers = @@ -143,4 +156,57 @@ private static void copyListVectorData( // For ListVector, we need to manually set lastSet to avoid offset buffer recalculation nonShadedListVector.setLastSet(valueCount - 1); } + + private static void copyListToFixedSizeListVectorData( + org.apache.fluss.shaded.arrow.org.apache.arrow.vector.complex.ListVector + shadedListVector, + FixedSizeListVector nonShadedFixedSizeListVector) { + + int valueCount = shadedListVector.getValueCount(); + int expectedListSize = nonShadedFixedSizeListVector.getListSize(); + int expectedTotalValueCount = valueCount * expectedListSize; + + // Validate that backing data vector element count matches expected fixed-size layout. + // If every list has exactly expectedListSize elements, the total must be + // valueCount * expectedListSize. + int totalValueCount = shadedListVector.getDataVector().getValueCount(); + if (totalValueCount != expectedTotalValueCount) { + throw new IllegalArgumentException( + String.format( + "Total child elements (%d) does not match expected %d for FixedSizeList conversion.", + totalValueCount, expectedTotalValueCount)); + } + + // Copy the child data vector recursively (e.g., the float values) + org.apache.fluss.shaded.arrow.org.apache.arrow.vector.FieldVector shadedDataVector = + shadedListVector.getDataVector(); + FieldVector nonShadedDataVector = nonShadedFixedSizeListVector.getDataVector(); + + if (shadedDataVector != null && nonShadedDataVector != null) { + copyVectorData(shadedDataVector, nonShadedDataVector); + } + + // FixedSizeListVector only has a validity buffer (no offset buffer). + // Copy the validity buffer from the shaded ListVector. + List shadedBuffers = + shadedListVector.getFieldBuffers(); + List nonShadedBuffers = nonShadedFixedSizeListVector.getFieldBuffers(); + + // Both ListVector and FixedSizeListVector have validity as their first buffer + if (!shadedBuffers.isEmpty() && !nonShadedBuffers.isEmpty()) { + org.apache.fluss.shaded.arrow.org.apache.arrow.memory.ArrowBuf shadedValidityBuf = + shadedBuffers.get(0); + ArrowBuf nonShadedValidityBuf = nonShadedBuffers.get(0); + + long size = Math.min(shadedValidityBuf.capacity(), nonShadedValidityBuf.capacity()); + if (size > 0) { + ByteBuffer srcBuffer = shadedValidityBuf.nioBuffer(0, (int) size); + srcBuffer.position(0); + srcBuffer.limit((int) Math.min(size, Integer.MAX_VALUE)); + nonShadedValidityBuf.setBytes(0, srcBuffer); + } + } + + nonShadedFixedSizeListVector.setValueCount(valueCount); + } } diff --git a/fluss-lake/fluss-lake-lance/src/main/java/org/apache/fluss/lake/lance/utils/LanceArrowUtils.java b/fluss-lake/fluss-lake-lance/src/main/java/org/apache/fluss/lake/lance/utils/LanceArrowUtils.java index 681367a327..e2c7d814a8 100644 --- a/fluss-lake/fluss-lake-lance/src/main/java/org/apache/fluss/lake/lance/utils/LanceArrowUtils.java +++ b/fluss-lake/fluss-lake-lance/src/main/java/org/apache/fluss/lake/lance/utils/LanceArrowUtils.java @@ -47,31 +47,76 @@ import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; +import static org.apache.fluss.utils.Preconditions.checkArgument; + /** * Utilities for converting Fluss RowType to non-shaded Arrow Schema. This is needed because Lance * requires non-shaded Arrow API. */ public class LanceArrowUtils { + /** Property suffix for configuring a fixed-size list Arrow type on array columns. */ + public static final String FIXED_SIZE_LIST_SIZE_SUFFIX = ".arrow.fixed-size-list.size"; + /** Returns the non-shaded Arrow schema of the specified Fluss RowType. */ public static Schema toArrowSchema(RowType rowType) { + return toArrowSchema(rowType, Collections.emptyMap()); + } + + /** + * Returns the non-shaded Arrow schema of the specified Fluss RowType, using table properties to + * determine whether array columns should use FixedSizeList instead of List. + * + *

When a table property {@code .arrow.fixed-size-list.size} is set, the + * corresponding ARRAY column will be emitted as {@code FixedSizeList(size)} instead of + * {@code List}. + */ + public static Schema toArrowSchema(RowType rowType, Map tableProperties) { List fields = rowType.getFields().stream() - .map(f -> toArrowField(f.getName(), f.getType())) + .map(f -> toArrowField(f.getName(), f.getType(), tableProperties)) .collect(Collectors.toList()); return new Schema(fields); } - private static Field toArrowField(String fieldName, DataType logicalType) { - FieldType fieldType = - new FieldType(logicalType.isNullable(), toArrowType(logicalType), null); + private static Field toArrowField( + String fieldName, DataType logicalType, Map tableProperties) { + ArrowType arrowType; + if (logicalType instanceof ArrayType && tableProperties != null) { + String sizeStr = tableProperties.get(fieldName + FIXED_SIZE_LIST_SIZE_SUFFIX); + if (sizeStr != null) { + int listSize = -1; + try { + listSize = Integer.parseInt(sizeStr); + } catch (NumberFormatException ignored) { + // Not really ignored, IllegalArgumentEx still thrown below. + // This removes duplicate boilerplates for throwing IAE + } + + checkArgument( + listSize > 0, + "Invalid value '%s' for property '%s'. Expected a positive integer.", + sizeStr, + fieldName + FIXED_SIZE_LIST_SIZE_SUFFIX); + arrowType = new ArrowType.FixedSizeList(listSize); + } else { + arrowType = toArrowType(logicalType); + } + } else { + arrowType = toArrowType(logicalType); + } + FieldType fieldType = new FieldType(logicalType.isNullable(), arrowType, null); List children = null; if (logicalType instanceof ArrayType) { children = Collections.singletonList( - toArrowField("element", ((ArrayType) logicalType).getElementType())); + toArrowField( + "element", + ((ArrayType) logicalType).getElementType(), + tableProperties)); } return new Field(fieldName, fieldType, children); } diff --git a/fluss-lake/fluss-lake-lance/src/test/java/org/apache/fluss/lake/lance/tiering/LanceTieringTest.java b/fluss-lake/fluss-lake-lance/src/test/java/org/apache/fluss/lake/lance/tiering/LanceTieringTest.java index 59ca5b5178..a7a13ff5e2 100644 --- a/fluss-lake/fluss-lake-lance/src/test/java/org/apache/fluss/lake/lance/tiering/LanceTieringTest.java +++ b/fluss-lake/fluss-lake-lance/src/test/java/org/apache/fluss/lake/lance/tiering/LanceTieringTest.java @@ -37,6 +37,7 @@ import org.apache.fluss.record.GenericRecord; import org.apache.fluss.record.LogRecord; import org.apache.fluss.row.BinaryString; +import org.apache.fluss.row.GenericArray; import org.apache.fluss.row.GenericRow; import org.apache.fluss.types.DataTypes; import org.apache.fluss.utils.types.Tuple2; @@ -46,7 +47,9 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.FixedSizeListVector; import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.types.pojo.ArrowType; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.params.ParameterizedTest; @@ -69,6 +72,8 @@ /** The UT for tiering to Lance via {@link LanceLakeTieringFactory}. */ class LanceTieringTest { + private static final int EMBEDDING_LIST_SIZE = 4; + private @TempDir File tempWarehouseDir; private LanceLakeTieringFactory lanceLakeTieringFactory; private Configuration configuration; @@ -91,13 +96,16 @@ void testTieringWriteTable(boolean isPartitioned) throws Exception { TablePath tablePath = TablePath.of("lance", "logTable"); Map customProperties = new HashMap<>(); customProperties.put("lance.batch_size", "256"); + customProperties.put( + "embedding" + LanceArrowUtils.FIXED_SIZE_LIST_SIZE_SUFFIX, + String.valueOf(EMBEDDING_LIST_SIZE)); LanceConfig config = LanceConfig.from( configuration.toMap(), customProperties, tablePath.getDatabaseName(), tablePath.getTableName()); - Schema schema = createTable(config); + Schema schema = createTable(config, customProperties); TableDescriptor descriptor = TableDescriptor.builder() @@ -180,6 +188,13 @@ void testTieringWriteTable(boolean isPartitioned) throws Exception { new RootAllocator(), config.getDatasetUri(), LanceConfig.genReadOptionFromConfig(config))) { + // verify the embedding column uses FixedSizeList in the Lance schema + org.apache.arrow.vector.types.pojo.Field embeddingField = + dataset.getSchema().findField("embedding"); + assertThat(embeddingField.getType()).isInstanceOf(ArrowType.FixedSizeList.class); + assertThat(((ArrowType.FixedSizeList) embeddingField.getType()).getListSize()) + .isEqualTo(EMBEDDING_LIST_SIZE); + ArrowReader reader = dataset.newScan().scanBatches(); VectorSchemaRoot readerRoot = reader.getVectorSchemaRoot(); @@ -189,8 +204,7 @@ void testTieringWriteTable(boolean isPartitioned) throws Exception { reader.loadNextBatch(); Tuple2 partitionBucket = Tuple2.of(partition, bucket); List expectRecords = recordsByBucket.get(partitionBucket); - verifyLogTableRecords( - readerRoot, expectRecords, bucket, isPartitioned, partition); + verifyLogTableRecords(readerRoot, expectRecords); } } assertThat(reader.loadNextBatch()).isFalse(); @@ -216,14 +230,13 @@ void testTieringWriteTable(boolean isPartitioned) throws Exception { } } - private void verifyLogTableRecords( - VectorSchemaRoot root, - List expectRecords, - int expectBucket, - boolean isPartitioned, - @Nullable String partition) - throws Exception { + private void verifyLogTableRecords(VectorSchemaRoot root, List expectRecords) { assertThat(root.getRowCount()).isEqualTo(expectRecords.size()); + + // verify the embedding vector is a FixedSizeListVector + assertThat(root.getVector("embedding")).isInstanceOf(FixedSizeListVector.class); + FixedSizeListVector embeddingVector = (FixedSizeListVector) root.getVector("embedding"); + for (int i = 0; i < expectRecords.size(); i++) { LogRecord expectRecord = expectRecords.get(i); // check business columns: @@ -233,6 +246,13 @@ private void verifyLogTableRecords( .isEqualTo(expectRecord.getRow().getString(1).toString()); assertThat(((VarCharVector) root.getVector(2)).getObject(i).toString()) .isEqualTo(expectRecord.getRow().getString(2).toString()); + // check embedding column + java.util.List embeddingValues = embeddingVector.getObject(i); + assertThat(embeddingValues).hasSize(EMBEDDING_LIST_SIZE); + org.apache.fluss.row.InternalArray expectedArray = expectRecord.getRow().getArray(3); + for (int j = 0; j < EMBEDDING_LIST_SIZE; j++) { + assertThat((Float) embeddingValues.get(j)).isEqualTo(expectedArray.getFloat(j)); + } } } @@ -296,19 +316,21 @@ private Tuple2, List> genLogTableRecords( List logRecords = new ArrayList<>(); for (int i = 0; i < numRecords; i++) { GenericRow genericRow; - if (partition != null) { - // Partitioned table: include partition field in data - genericRow = new GenericRow(3); // c1, c2, c3(partition) - genericRow.setField(0, i); - genericRow.setField(1, BinaryString.fromString("bucket" + bucket + "_" + i)); - genericRow.setField(2, BinaryString.fromString(partition)); // partition field - } else { - // Non-partitioned table - genericRow = new GenericRow(3); - genericRow.setField(0, i); - genericRow.setField(1, BinaryString.fromString("bucket" + bucket + "_" + i)); + + // Partitioned table: include partition field in data + genericRow = new GenericRow(4); // c1, c2, c3(partition), embedding + genericRow.setField(0, i); + genericRow.setField(1, BinaryString.fromString("bucket" + bucket + "_" + i)); + + if (partition == null) { genericRow.setField(2, BinaryString.fromString("bucket" + bucket)); + } else { + genericRow.setField(2, BinaryString.fromString(partition)); } + + genericRow.setField( + 3, new GenericArray(new float[] {0.1f * i, 0.2f * i, 0.3f * i, 0.4f * i})); + LogRecord logRecord = new GenericRecord( i, System.currentTimeMillis(), ChangeType.APPEND_ONLY, genericRow); @@ -317,16 +339,19 @@ private Tuple2, List> genLogTableRecords( return Tuple2.of(logRecords, logRecords); } - private Schema createTable(LanceConfig config) { + private Schema createTable(LanceConfig config, Map customProperties) { List columns = new ArrayList<>(); columns.add(new Schema.Column("c1", DataTypes.INT())); columns.add(new Schema.Column("c2", DataTypes.STRING())); columns.add(new Schema.Column("c3", DataTypes.STRING())); + columns.add(new Schema.Column("embedding", DataTypes.ARRAY(DataTypes.FLOAT()))); Schema.Builder schemaBuilder = Schema.newBuilder().fromColumns(columns); Schema schema = schemaBuilder.build(); WriteParams params = LanceConfig.genWriteParamsFromConfig(config); LanceDatasetAdapter.createDataset( - config.getDatasetUri(), LanceArrowUtils.toArrowSchema(schema.getRowType()), params); + config.getDatasetUri(), + LanceArrowUtils.toArrowSchema(schema.getRowType(), customProperties), + params); return schema; } diff --git a/fluss-lake/fluss-lake-lance/src/test/java/org/apache/fluss/lake/lance/utils/ArrowDataConverterTest.java b/fluss-lake/fluss-lake-lance/src/test/java/org/apache/fluss/lake/lance/utils/ArrowDataConverterTest.java new file mode 100644 index 0000000000..f8fbd8b383 --- /dev/null +++ b/fluss-lake/fluss-lake-lance/src/test/java/org/apache/fluss/lake/lance/utils/ArrowDataConverterTest.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.fluss.lake.lance.utils; + +import org.apache.fluss.lake.lance.tiering.ShadedArrowBatchWriter; +import org.apache.fluss.row.GenericArray; +import org.apache.fluss.row.GenericRow; +import org.apache.fluss.types.DataTypes; +import org.apache.fluss.types.RowType; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link ArrowDataConverter#convertToNonShaded}. */ +class ArrowDataConverterTest { + + private org.apache.fluss.shaded.arrow.org.apache.arrow.memory.BufferAllocator shadedAllocator; + private BufferAllocator nonShadedAllocator; + + @BeforeEach + void setUp() { + shadedAllocator = + new org.apache.fluss.shaded.arrow.org.apache.arrow.memory.RootAllocator( + Long.MAX_VALUE); + nonShadedAllocator = new RootAllocator(Long.MAX_VALUE); + } + + @AfterEach + void tearDown() { + shadedAllocator.close(); + nonShadedAllocator.close(); + } + + @Test + void testConvertListToFixedSizeList() { + int listSize = 3; + RowType rowType = + DataTypes.ROW(DataTypes.FIELD("embedding", DataTypes.ARRAY(DataTypes.FLOAT()))); + + Map properties = new HashMap<>(); + properties.put("embedding.arrow.fixed-size-list.size", String.valueOf(listSize)); + Schema nonShadedSchema = LanceArrowUtils.toArrowSchema(rowType, properties); + + // Populate shaded root with 3 rows, each having a list of exactly 3 floats + float[][] data = { + {1.0f, 2.0f, 3.0f}, + {4.0f, 5.0f, 6.0f}, + {7.0f, 8.0f, 9.0f} + }; + + try (ShadedArrowBatchWriter writer = new ShadedArrowBatchWriter(shadedAllocator, rowType)) { + for (float[] floats : data) { + GenericRow row = new GenericRow(1); + row.setField(0, new GenericArray(floats)); + writer.writeRow(row); + } + writer.finish(); + + try (VectorSchemaRoot nonShadedRoot = + ArrowDataConverter.convertToNonShaded( + writer.getShadedRoot(), nonShadedAllocator, nonShadedSchema)) { + assertThat(nonShadedRoot.getRowCount()).isEqualTo(3); + assertThat(nonShadedRoot.getVector("embedding")) + .isInstanceOf(FixedSizeListVector.class); + + FixedSizeListVector fixedSizeListVector = + (FixedSizeListVector) nonShadedRoot.getVector("embedding"); + assertThat(fixedSizeListVector.getListSize()).isEqualTo(listSize); + + for (int i = 0; i < data.length; i++) { + List values = fixedSizeListVector.getObject(i); + assertThat(values).hasSize(listSize); + for (int j = 0; j < listSize; j++) { + assertThat((Float) values.get(j)).isEqualTo(data[i][j]); + } + } + } + } + } + + @Test + void testConvertListToFixedSizeListMismatchedCount() { + int listSize = 3; + RowType rowType = + DataTypes.ROW(DataTypes.FIELD("embedding", DataTypes.ARRAY(DataTypes.FLOAT()))); + + Map properties = new HashMap<>(); + properties.put("embedding.arrow.fixed-size-list.size", String.valueOf(listSize)); + Schema nonShadedSchema = LanceArrowUtils.toArrowSchema(rowType, properties); + + // Write 2 rows with 2 elements each: total child elements = 4, expected = 2*3 = 6 + try (ShadedArrowBatchWriter writer = new ShadedArrowBatchWriter(shadedAllocator, rowType)) { + GenericRow row1 = new GenericRow(1); + row1.setField(0, new GenericArray(new float[] {1.0f, 2.0f})); + writer.writeRow(row1); + + GenericRow row2 = new GenericRow(1); + row2.setField(0, new GenericArray(new float[] {3.0f, 4.0f})); + writer.writeRow(row2); + writer.finish(); + + assertThatThrownBy( + () -> + ArrowDataConverter.convertToNonShaded( + writer.getShadedRoot(), + nonShadedAllocator, + nonShadedSchema)) + .isInstanceOf(IllegalArgumentException.class); + + // Verify no memory leaked from the failed conversion + assertThat(nonShadedAllocator.getAllocatedMemory()).isZero(); + } + } +} diff --git a/fluss-lake/fluss-lake-lance/src/test/java/org/apache/fluss/lake/lance/utils/LanceArrowUtilsTest.java b/fluss-lake/fluss-lake-lance/src/test/java/org/apache/fluss/lake/lance/utils/LanceArrowUtilsTest.java new file mode 100644 index 0000000000..1319d6079a --- /dev/null +++ b/fluss-lake/fluss-lake-lance/src/test/java/org/apache/fluss/lake/lance/utils/LanceArrowUtilsTest.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.fluss.lake.lance.utils; + +import org.apache.fluss.types.DataTypes; +import org.apache.fluss.types.RowType; + +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link LanceArrowUtils#toArrowSchema(RowType, Map)}. */ +class LanceArrowUtilsTest { + + @Test + void testArrayColumnWithoutProperty() { + RowType rowType = + DataTypes.ROW(DataTypes.FIELD("embedding", DataTypes.ARRAY(DataTypes.FLOAT()))); + + Schema schema = LanceArrowUtils.toArrowSchema(rowType, Collections.emptyMap()); + + Field embeddingField = schema.findField("embedding"); + assertThat(embeddingField.getType()).isInstanceOf(ArrowType.List.class); + } + + @Test + void testArrayColumnWithFixedSizeListProperty() { + RowType rowType = + DataTypes.ROW(DataTypes.FIELD("embedding", DataTypes.ARRAY(DataTypes.FLOAT()))); + + Map properties = new HashMap<>(); + properties.put("embedding.arrow.fixed-size-list.size", "4"); + + Schema schema = LanceArrowUtils.toArrowSchema(rowType, properties); + + Field embeddingField = schema.findField("embedding"); + assertThat(embeddingField.getType()).isInstanceOf(ArrowType.FixedSizeList.class); + assertThat(((ArrowType.FixedSizeList) embeddingField.getType()).getListSize()).isEqualTo(4); + + // Child should still be a float element + assertThat(embeddingField.getChildren()).hasSize(1); + assertThat(embeddingField.getChildren().get(0).getName()).isEqualTo("element"); + } + + @Test + void testArrayColumnWithZeroSize() { + RowType rowType = + DataTypes.ROW(DataTypes.FIELD("embedding", DataTypes.ARRAY(DataTypes.FLOAT()))); + + Map properties = new HashMap<>(); + properties.put("embedding.arrow.fixed-size-list.size", "0"); + + assertThatThrownBy(() -> LanceArrowUtils.toArrowSchema(rowType, properties)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + void testArrayColumnWithNegativeSize() { + RowType rowType = + DataTypes.ROW(DataTypes.FIELD("embedding", DataTypes.ARRAY(DataTypes.FLOAT()))); + + Map properties = new HashMap<>(); + properties.put("embedding.arrow.fixed-size-list.size", "-1"); + + assertThatThrownBy(() -> LanceArrowUtils.toArrowSchema(rowType, properties)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + void testArrayColumnWithNonNumericSize() { + RowType rowType = + DataTypes.ROW(DataTypes.FIELD("embedding", DataTypes.ARRAY(DataTypes.FLOAT()))); + + Map properties = new HashMap<>(); + properties.put("embedding.arrow.fixed-size-list.size", "abc"); + + assertThatThrownBy(() -> LanceArrowUtils.toArrowSchema(rowType, properties)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + void testToArrowSchemaWithEmptyProperties() { + RowType rowType = + DataTypes.ROW(DataTypes.FIELD("embedding", DataTypes.ARRAY(DataTypes.FLOAT()))); + + Schema schema = LanceArrowUtils.toArrowSchema(rowType, Collections.emptyMap()); + + Field embeddingField = schema.findField("embedding"); + assertThat(embeddingField.getType()).isInstanceOf(ArrowType.List.class); + } + + @Test + void testToArrowSchemaWithNullProperties() { + RowType rowType = + DataTypes.ROW(DataTypes.FIELD("embedding", DataTypes.ARRAY(DataTypes.FLOAT()))); + + Schema schema = LanceArrowUtils.toArrowSchema(rowType, null); + + Field embeddingField = schema.findField("embedding"); + assertThat(embeddingField.getType()).isInstanceOf(ArrowType.List.class); + } +}