From bd30a21e0816ff941cdfb5cbd8fa820101d2d316 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 13:52:10 -0400 Subject: [PATCH 1/7] feat(core): add LambdaBuilder for build-time validation of lambda parameter references Introduces LambdaBuilder, a context-aware builder that maintains a lambda parameter stack (lambdaContext) to validate parameter references at build time. Nested lambdas use the same builder, ensuring stepsOut is computed automatically. Mirrors the lambdaContext pattern from substrait-go. --- .../io/substrait/expression/Expression.java | 4 - .../substrait/expression/LambdaBuilder.java | 99 ++++ .../proto/ProtoExpressionConverter.java | 4 +- .../ExpressionCopyOnWriteVisitor.java | 6 +- .../expression/LambdaBuilderTest.java | 44 ++ .../proto/LambdaExpressionRoundtripTest.java | 473 ++++++++---------- .../expression/RexExpressionConverter.java | 3 +- .../isthmus/LambdaExpressionTest.java | 127 +---- 8 files changed, 377 insertions(+), 383 deletions(-) create mode 100644 core/src/main/java/io/substrait/expression/LambdaBuilder.java create mode 100644 core/src/test/java/io/substrait/expression/LambdaBuilderTest.java diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 72510ad2c..b116cbdc2 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -775,10 +775,6 @@ public Type getType() { return Type.withNullability(false).func(paramTypes, returnType); } - public static ImmutableExpression.Lambda.Builder builder() { - return ImmutableExpression.Lambda.builder(); - } - @Override public R accept( ExpressionVisitor visitor, C context) throws E { diff --git a/core/src/main/java/io/substrait/expression/LambdaBuilder.java b/core/src/main/java/io/substrait/expression/LambdaBuilder.java new file mode 100644 index 000000000..b2b89c02d --- /dev/null +++ b/core/src/main/java/io/substrait/expression/LambdaBuilder.java @@ -0,0 +1,99 @@ +package io.substrait.expression; + +import io.substrait.type.Type; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +/** + * Builds lambda expressions with build-time validation of parameter references. + * + *

Maintains a stack of lambda parameter scopes. Each call to {@link #lambda} pushes parameters + * onto the stack, builds the body via a callback, and pops. Nested lambdas simply call {@code + * lambda()} again on the same builder. + * + *

The callback receives a {@link Scope} handle for creating validated parameter references. The + * correct {@code stepsOut} value is computed automatically from the stack. + * + *

{@code
+ * LambdaBuilder lb = new LambdaBuilder();
+ *
+ * // Simple: (x: i32) -> x
+ * Expression.Lambda simple = lb.lambda(List.of(R.I32), x -> x.ref(0));
+ *
+ * // Nested: (x: i32) -> (y: i64) -> add(x, y)
+ * Expression.Lambda nested = lb.lambda(List.of(R.I32), x ->
+ *     lb.lambda(List.of(R.I64), y ->
+ *         add(x.ref(0), y.ref(0))
+ *     )
+ * );
+ * }
+ */ +public class LambdaBuilder { + private final List lambdaContext = new ArrayList<>(); + + /** + * Builds a lambda expression. The body function receives a {@link Scope} for creating validated + * parameter references. Nested lambdas are built by calling this method again inside the + * callback. + * + * @param paramTypes the lambda's parameter types + * @param bodyFn function that builds the lambda body given a scope handle + * @return the constructed lambda expression + */ + public Expression.Lambda lambda(List paramTypes, Function bodyFn) { + Type.Struct params = Type.Struct.builder().nullable(false).addAllFields(paramTypes).build(); + pushLambdaContext(params); + try { + int index = lambdaContext.size() - 1; + Scope scope = new Scope(index); + Expression body = bodyFn.apply(scope); + return ImmutableExpression.Lambda.builder().parameters(params).body(body).build(); + } finally { + popLambdaContext(); + } + } + + /** + * Pushes a lambda's parameters onto the context stack. This makes the parameters available for + * validation when building the lambda's body, and allows nested lambda parameter references to + * correctly compute their stepsOut values. + */ + private void pushLambdaContext(Type.Struct params) { + lambdaContext.add(params); + } + + /** + * Pops the most recently pushed lambda parameters from the context stack. Called after a lambda's + * body has been built, restoring the context to the enclosing lambda's scope. + */ + private void popLambdaContext() { + lambdaContext.remove(lambdaContext.size() - 1); + } + + /** + * A handle to a particular lambda's parameter scope. Use {@link #ref} to create validated + * parameter references. + */ + public class Scope { + private final int index; + + private Scope(int index) { + this.index = index; + } + + /** + * Creates a validated reference to a parameter of this lambda. The correct {@code stepsOut} + * value is computed automatically. + * + * @param paramIndex index of the parameter within this lambda's parameter struct + * @return a {@link FieldReference} pointing to the specified parameter + * @throws IndexOutOfBoundsException if paramIndex is out of bounds + */ + public FieldReference ref(int paramIndex) { + int stepsOut = lambdaContext.size() - 1 - index; + return FieldReference.newLambdaParameterReference( + paramIndex, lambdaContext.get(index), stepsOut); + } + } +} diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 290e88c49..426b0f56a 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -6,6 +6,7 @@ import io.substrait.expression.FieldReference.ReferenceSegment; import io.substrait.expression.FunctionArg; import io.substrait.expression.FunctionOption; +import io.substrait.expression.ImmutableExpression; import io.substrait.expression.WindowBound; import io.substrait.extension.ExtensionLookup; import io.substrait.extension.SimpleExtension; @@ -282,6 +283,7 @@ public Type visit(Type.Struct type) throws RuntimeException { case LAMBDA: { + // TODO: Add build-time validation of lambda parameter references during deserialization. io.substrait.proto.Expression.Lambda protoLambda = expr.getLambda(); Type.Struct parameters = (Type.Struct) @@ -299,7 +301,7 @@ public Type visit(Type.Struct type) throws RuntimeException { lambdaParameterStack.pop(); } - return Expression.Lambda.builder().parameters(parameters).body(body).build(); + return ImmutableExpression.Lambda.builder().parameters(parameters).body(body).build(); } // TODO enum. case ENUM: diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index 11fc14f27..b097614a0 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -8,6 +8,7 @@ import io.substrait.expression.ExpressionVisitor; import io.substrait.expression.FieldReference; import io.substrait.expression.FunctionArg; +import io.substrait.expression.ImmutableExpression; import io.substrait.util.EmptyVisitationContext; import java.util.List; import java.util.Optional; @@ -448,7 +449,10 @@ public Optional visit(Expression.Lambda lambda, EmptyVisitationConte return Optional.empty(); } return Optional.of( - Expression.Lambda.builder().from(lambda).body(newBody.orElse(lambda.body())).build()); + ImmutableExpression.Lambda.builder() + .from(lambda) + .body(newBody.orElse(lambda.body())) + .build()); } // utilities diff --git a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java new file mode 100644 index 000000000..7d8f8976b --- /dev/null +++ b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java @@ -0,0 +1,44 @@ +package io.substrait.expression; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.substrait.type.TypeCreator; +import java.util.List; +import org.junit.jupiter.api.Test; + +/** Tests for {@link LambdaBuilder} build-time validation. */ +class LambdaBuilderTest { + + static final TypeCreator R = TypeCreator.REQUIRED; + + final LambdaBuilder lb = new LambdaBuilder(); + + // (x: i32) -> x[5] — field index 5 is out of bounds (only 1 param) + @Test + void invalidFieldIndex_outOfBounds() { + assertThrows( + IndexOutOfBoundsException.class, () -> lb.lambda(List.of(R.I32), params -> params.ref(5))); + } + + // (x: i32) -> x[-1] — negative field index + @Test + void negativeFieldIndex() { + assertThrows(Exception.class, () -> lb.lambda(List.of(R.I32), params -> params.ref(-1))); + } + + // (x: i32) -> (y: i64) -> x[5] — outer field index 5 is out of bounds + @Test + void nestedOuterFieldIndexOutOfBounds() { + assertThrows( + IndexOutOfBoundsException.class, + () -> lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> outer.ref(5)))); + } + + // (x: i32) -> (y: i64) -> y[3] — inner field index 3 is out of bounds (only 1 param) + @Test + void nestedInnerFieldIndexOutOfBounds() { + assertThrows( + IndexOutOfBoundsException.class, + () -> lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> inner.ref(3)))); + } +} diff --git a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java index c0a979a94..3b9cfedad 100644 --- a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java @@ -1,361 +1,298 @@ package io.substrait.type.proto; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import io.substrait.TestBase; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FieldReference; +import io.substrait.expression.LambdaBuilder; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.type.Type; import java.util.List; import org.junit.jupiter.api.Test; -/** - * Tests for Lambda expression round-trip conversion through protobuf. Based on equivalent tests - * from substrait-go. - */ class LambdaExpressionRoundtripTest extends TestBase { - /** Test that lambdas with no parameters are valid. Building: () -> i32(42) : func<() -> i32> */ - @Test - void zeroParameterLambda() { - Type.Struct emptyParams = Type.Struct.builder().nullable(false).build(); + final LambdaBuilder lb = new LambdaBuilder(); - Expression body = ExpressionCreator.i32(false, 42); + // ==================== Single Lambda Tests ==================== - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(emptyParams).body(body).build(); + // () -> 42 + @Test + void zeroParameterLambda() { + Expression.Lambda lambda = lb.lambda(List.of(), params -> ExpressionCreator.i32(false, 42)); verifyRoundTrip(lambda); - // Verify the lambda type - Type lambdaType = lambda.getType(); - assertInstanceOf(Type.Func.class, lambdaType); - Type.Func funcType = (Type.Func) lambdaType; + Type.Func funcType = (Type.Func) lambda.getType(); assertEquals(0, funcType.parameterTypes().size()); assertEquals(R.I32, funcType.returnType()); } - /** Test valid stepsOut=0 references. Building: ($0: i32) -> $0 : func i32> */ + // (x: i32) -> x @Test - void validStepsOut0() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - // Lambda body references parameter 0 with stepsOut=0 - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(paramRef).build(); + void identityLambda() { + Expression.Lambda lambda = lb.lambda(List.of(R.I32), params -> params.ref(0)); verifyRoundTrip(lambda); - // Verify types - Type lambdaType = lambda.getType(); - assertInstanceOf(Type.Func.class, lambdaType); - Type.Func funcType = (Type.Func) lambdaType; + Type.Func funcType = (Type.Func) lambda.getType(); assertEquals(1, funcType.parameterTypes().size()); assertEquals(R.I32, funcType.parameterTypes().get(0)); assertEquals(R.I32, funcType.returnType()); + + assertInstanceOf(FieldReference.class, lambda.body()); + FieldReference ref = (FieldReference) lambda.body(); + assertTrue(ref.isLambdaParameterReference()); + assertEquals(0, ref.lambdaParameterReferenceStepsOut().orElse(-1)); } - /** - * Test valid field index with multiple parameters. Building: ($0: i32, $1: i64, $2: string) -> $2 - * : func<(i32, i64, string) -> string> - */ + // (x: i32, y: i64, z: string) -> z @Test void validFieldIndex() { - Type.Struct params = - Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); + Expression.Lambda lambda = lb.lambda(List.of(R.I32, R.I64, R.STRING), params -> params.ref(2)); - // Reference the 3rd parameter (string) - FieldReference paramRef = FieldReference.newLambdaParameterReference(2, params, 0); + verifyRoundTrip(lambda); + assertEquals(R.STRING, ((Type.Func) lambda.getType()).returnType()); + } + // (x: i32) -> 42 + @Test + void lambdaWithLiteralBody() { Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(paramRef).build(); + lb.lambda(List.of(R.I32), params -> ExpressionCreator.i32(false, 42)); verifyRoundTrip(lambda); - - // Verify return type is string - Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals(R.STRING, funcType.returnType()); + assertInstanceOf(Expression.I32Literal.class, lambda.body()); } - /** Test type resolution for different parameter types. */ + // Parameterized: (params...) -> params[fieldIndex], verifying type resolution @Test void typeResolution() { - // Test cases: (paramTypes, fieldIndex, expectedReturnType) - record TestCase(List paramTypes, int fieldIndex, Type expectedType) {} + record TestCase(String name, List paramTypes, int fieldIndex, Type expectedType) {} List testCases = List.of( - new TestCase(List.of(R.I32), 0, R.I32), - new TestCase(List.of(R.I32, R.I64), 1, R.I64), - new TestCase(List.of(R.I32, R.I64, R.STRING), 2, R.STRING), - new TestCase(List.of(R.FP64), 0, R.FP64), - new TestCase(List.of(R.BOOLEAN, R.DATE, R.TIMESTAMP), 1, R.DATE)); + new TestCase("first param (i32)", List.of(R.I32), 0, R.I32), + new TestCase("second param (i64)", List.of(R.I32, R.I64), 1, R.I64), + new TestCase("third param (string)", List.of(R.I32, R.I64, R.STRING), 2, R.STRING), + new TestCase("float64 param", List.of(R.FP64), 0, R.FP64), + new TestCase("date param", List.of(R.BOOLEAN, R.DATE, R.TIMESTAMP), 1, R.DATE)); for (TestCase tc : testCases) { - Type.Struct params = - Type.Struct.builder().nullable(false).addAllFields(tc.paramTypes).build(); - - FieldReference paramRef = - FieldReference.newLambdaParameterReference(tc.fieldIndex, params, 0); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(paramRef).build(); + Expression.Lambda lambda = lb.lambda(tc.paramTypes, params -> params.ref(tc.fieldIndex)); verifyRoundTrip(lambda); - // Verify the body type matches expected - assertEquals( - tc.expectedType, - lambda.body().getType(), - "Body type should match referenced parameter type"); - - // Verify lambda return type + assertEquals(tc.expectedType, lambda.body().getType(), tc.name + ": body type mismatch"); Type.Func funcType = (Type.Func) lambda.getType(); assertEquals( - tc.expectedType, funcType.returnType(), "Lambda return type should match body type"); + tc.expectedType, funcType.returnType(), tc.name + ": lambda return type mismatch"); } } - /** - * Test nested lambda with outer reference. Building: ($0: i64, $1: i64) -> (($0: i32) -> - * outer[$0] : i64) : func<(i64, i64) -> func i64>> - */ + // (x: i32, y: string) -> y — verify full Func type structure @Test - void nestedLambdaWithOuterRef() { - Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64, R.I64).build(); - - Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - // Inner lambda references outer's parameter 0 with stepsOut=1 - FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); - - Expression.Lambda innerLambda = - Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); - - Expression.Lambda outerLambda = - Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); - - verifyRoundTrip(outerLambda); + void lambdaGetTypeReturnsFunc() { + Expression.Lambda lambda = lb.lambda(List.of(R.I32, R.STRING), params -> params.ref(1)); - // Verify structure - assertInstanceOf(Expression.Lambda.class, outerLambda.body()); - Expression.Lambda resultInner = (Expression.Lambda) outerLambda.body(); - assertEquals(1, resultInner.parameters().fields().size()); + Type.Func funcType = (Type.Func) lambda.getType(); + assertEquals(2, funcType.parameterTypes().size()); + assertEquals(R.I32, funcType.parameterTypes().get(0)); + assertEquals(R.STRING, funcType.parameterTypes().get(1)); + assertEquals(R.STRING, funcType.returnType()); } - /** - * Test outer reference type resolution in nested lambdas. Building: ($0: i32, $1: i64, $2: - * string) -> (($0: fp64) -> outer[$2] : string) : func<...> - */ + // (x: i32, y: i64, z: string) -> ... — verify FieldReference metadata for each param @Test - void outerRefTypeResolution() { - Type.Struct outerParams = - Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); - - Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.FP64).build(); - - // Inner references outer's field 2 (string) with stepsOut=1 - FieldReference outerRef = FieldReference.newLambdaParameterReference(2, outerParams, 1); - - Expression.Lambda innerLambda = - Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); - - Expression.Lambda outerLambda = - Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); - - verifyRoundTrip(outerLambda); - - // Verify inner lambda's return type is string (from outer param 2) - Expression.Lambda resultInner = (Expression.Lambda) outerLambda.body(); - Type.Func innerFuncType = (Type.Func) resultInner.getType(); - assertEquals( - R.STRING, - innerFuncType.returnType(), - "Inner lambda return type should be string from outer.$2"); - - // Verify body's type is also string - assertEquals(R.STRING, resultInner.body().getType(), "Body type should be string"); + void parameterReferenceMetadata() { + List paramTypes = List.of(R.I32, R.I64, R.STRING); + + lb.lambda( + paramTypes, + params -> { + for (int i = 0; i < 3; i++) { + FieldReference ref = params.ref(i); + assertTrue(ref.isLambdaParameterReference()); + assertFalse(ref.isOuterReference()); + assertFalse(ref.isSimpleRootReference()); + assertEquals(0, ref.lambdaParameterReferenceStepsOut().orElse(-1)); + assertEquals(paramTypes.get(i), ref.getType()); + } + return params.ref(0); + }); } - /** - * Test deeply nested field ref inside Cast. Building: ($0: i32) -> cast($0 as i64) : func - * i64> - */ - @Test - void deeplyNestedFieldRef() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + // ==================== Expression Body Tests ==================== - Expression.Cast castExpr = - (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(castExpr).build(); + // (x: i64) -> (y1: i64, y2: i64) -> y1 * x + y2 + @Test + void nestedLambdaWithArithmeticBody() { + String ARITH = DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC; + + Expression.Lambda result = + lb.lambda( + List.of(R.I64), + outer -> + lb.lambda( + List.of(R.I64, R.I64), + inner -> { + // y1 * x + Expression multiply = + sb.scalarFn(ARITH, "multiply:i64_i64", R.I64, inner.ref(0), outer.ref(0)); + // (y1 * x) + y2 + return sb.scalarFn(ARITH, "add:i64_i64", R.I64, multiply, inner.ref(1)); + })); + + verifyRoundTrip(result); + + // Outer lambda returns a func type + Type.Func outerFuncType = (Type.Func) result.getType(); + assertInstanceOf(Type.Func.class, outerFuncType.returnType()); + + // Inner lambda returns i64 + Expression.Lambda innerLambda = (Expression.Lambda) result.body(); + Type.Func innerFuncType = (Type.Func) innerLambda.getType(); + assertEquals(R.I64, innerFuncType.returnType()); + + // Inner body is a scalar function (add) + assertInstanceOf(Expression.ScalarFunctionInvocation.class, innerLambda.body()); + } - verifyRoundTrip(lambda); + // ==================== Nested Lambda Tests ==================== - // Verify the nested FieldRef has its type resolved - Expression.Cast resultCast = (Expression.Cast) lambda.body(); - assertInstanceOf(FieldReference.class, resultCast.input()); - FieldReference resultFieldRef = (FieldReference) resultCast.input(); + // (x: i64, y: i64) -> (z: i32) -> x + @Test + void nestedLambdaWithOuterRef() { + Expression.Lambda result = + lb.lambda(List.of(R.I64, R.I64), outer -> lb.lambda(List.of(R.I32), inner -> outer.ref(0))); - assertNotNull(resultFieldRef.getType(), "Nested FieldRef should have type resolved"); - assertEquals(R.I32, resultFieldRef.getType(), "Should resolve to i32"); + verifyRoundTrip(result); - // Verify lambda return type is i64 (cast output) - Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals(R.I64, funcType.returnType()); + Expression.Lambda resultInner = (Expression.Lambda) result.body(); + assertEquals(1, resultInner.parameters().fields().size()); + assertEquals(R.I64, resultInner.body().getType()); } - /** - * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). Building: ($0: i32) -> cast(cast($0 - * as i64) as string) : func string> - */ + // (x: i32, y: i64, z: string) -> (w: fp64) -> z @Test - void doublyNestedFieldRef() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - Expression.Cast innerCast = - (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - Expression.Cast outerCast = - (Expression.Cast) - ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(outerCast).build(); + void nestedLambdaOuterRefTypeResolution() { + Expression.Lambda result = + lb.lambda( + List.of(R.I32, R.I64, R.STRING), + outer -> lb.lambda(List.of(R.FP64), inner -> outer.ref(2))); - verifyRoundTrip(lambda); - - // Navigate to the deeply nested FieldRef (2 levels deep) - Expression.Cast resultOuter = (Expression.Cast) lambda.body(); - Expression.Cast resultInner = (Expression.Cast) resultOuter.input(); - FieldReference resultFieldRef = (FieldReference) resultInner.input(); + verifyRoundTrip(result); - // Verify type is resolved even at depth 2 - assertNotNull(resultFieldRef.getType(), "FieldRef at depth 2 should have type resolved"); - assertEquals(R.I32, resultFieldRef.getType()); + Expression.Lambda resultInner = (Expression.Lambda) result.body(); + assertEquals(R.STRING, ((Type.Func) resultInner.getType()).returnType()); } - /** - * Test lambda with literal body (no parameter references). Building: ($0: i32) -> 42 : func i32> - */ + // (x: i32) -> (y: i64) -> y @Test - void lambdaWithLiteralBody() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + void nestedLambdaInnerRefOnly() { + Expression.Lambda result = + lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> inner.ref(0))); - Expression body = ExpressionCreator.i32(false, 42); + verifyRoundTrip(result); - Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); - - verifyRoundTrip(lambda); + Expression.Lambda innerLambda = (Expression.Lambda) result.body(); + assertEquals(R.I64, innerLambda.body().getType()); + assertInstanceOf(Type.Func.class, ((Type.Func) result.getType()).returnType()); } - /** Test lambda getType returns correct Func type. */ + // (x: i32) -> (y: i64) -> (x, y) — body references both outer and inner params @Test - void lambdaGetTypeReturnsFunc() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32, R.STRING).build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(1, params, 0); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(paramRef).build(); - - Type lambdaType = lambda.getType(); - - assertInstanceOf(Type.Func.class, lambdaType); - Type.Func funcType = (Type.Func) lambdaType; - - assertEquals(2, funcType.parameterTypes().size()); - assertEquals(R.I32, funcType.parameterTypes().get(0)); - assertEquals(R.STRING, funcType.parameterTypes().get(1)); - assertEquals(R.STRING, funcType.returnType()); // body references param 1 which is STRING + void nestedLambdaBothInnerAndOuterRefs() { + Expression.Lambda result = + lb.lambda( + List.of(R.I32), + outer -> + lb.lambda( + List.of(R.I64), + inner -> { + FieldReference innerRef = inner.ref(0); + assertEquals(R.I64, innerRef.getType()); + assertEquals(0, innerRef.lambdaParameterReferenceStepsOut().orElse(-1)); + + FieldReference outerRef = outer.ref(0); + assertEquals(R.I32, outerRef.getType()); + assertEquals(1, outerRef.lambdaParameterReferenceStepsOut().orElse(-1)); + + return innerRef; + })); + + verifyRoundTrip(result); } - // ==================== Validation Error Tests ==================== - - /** - * Test that invalid outer reference (stepsOut too high) fails during proto conversion. Building: - * ($0: i32) -> outer[$0] : INVALID (no outer lambda, stepsOut=1) - */ + // (a: i32, b: string) -> (c: i64, d: fp64) -> b — verify all 4 params resolve correctly @Test - void invalidOuterRef_stepsOutTooHigh() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - // Create a parameter reference with stepsOut=1 but no outer lambda exists - FieldReference invalidRef = FieldReference.newLambdaParameterReference(0, params, 1); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(invalidRef).build(); - - // Convert to proto - this should work - io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(lambda); - - // Converting back should fail because stepsOut=1 references non-existent outer lambda - assertThrows( - IllegalArgumentException.class, - () -> { - protoExpressionConverter.from(protoExpression); - }, - "Should fail when stepsOut references non-existent outer lambda"); + void nestedLambdaMultiParamCorrectResolution() { + Expression.Lambda result = + lb.lambda( + List.of(R.I32, R.STRING), + outer -> + lb.lambda( + List.of(R.I64, R.FP64), + inner -> { + assertEquals(R.I64, inner.ref(0).getType()); + assertEquals(R.FP64, inner.ref(1).getType()); + assertEquals(R.I32, outer.ref(0).getType()); + assertEquals(R.STRING, outer.ref(1).getType()); + + return outer.ref(1); + })); + + verifyRoundTrip(result); + + Expression.Lambda innerLambda = (Expression.Lambda) result.body(); + assertEquals(R.STRING, ((Type.Func) innerLambda.getType()).returnType()); } - /** - * Test that invalid field index (out of bounds) fails during proto conversion. Building: ($0: - * i32) -> $5 : INVALID (only has 1 param) - */ + // (x: i32) -> (y: i64) -> (z: string) -> x @Test - void invalidFieldIndex_outOfBounds() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - // Create a reference to field 5, but lambda only has 1 parameter (index 0) - // This will fail at build time since newLambdaParameterReference accesses fields.get(5) - assertThrows( - IndexOutOfBoundsException.class, - () -> { - FieldReference.newLambdaParameterReference(5, params, 0); - }, - "Should fail when field index is out of bounds"); + void tripleNestedLambdaRoundtrip() { + Expression.Lambda result = + lb.lambda( + List.of(R.I32), + outer -> + lb.lambda( + List.of(R.I64), mid -> lb.lambda(List.of(R.STRING), inner -> outer.ref(0)))); + + verifyRoundTrip(result); + + Expression.Lambda l1 = (Expression.Lambda) result.body(); + Expression.Lambda l2 = (Expression.Lambda) l1.body(); + assertEquals(R.I32, l2.body().getType()); } - /** - * Test nested invalid outer ref (stepsOut=2 but only 1 outer lambda). Building: ($0: i64) -> - * (($0: i32) -> outer.outer[$0]) : INVALID (no grandparent lambda) - */ + // (x: i32) -> (y: i64) -> (z: string) -> ... — verify stepsOut is auto-computed at each level @Test - void nestedInvalidOuterRef() { - Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); - - Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - // Inner lambda references stepsOut=2, but only 1 outer lambda exists - FieldReference invalidRef = FieldReference.newLambdaParameterReference(0, outerParams, 2); - - Expression.Lambda innerLambda = - Expression.Lambda.builder().parameters(innerParams).body(invalidRef).build(); - - Expression.Lambda outerLambda = - Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); - - // Convert to proto - io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(outerLambda); - - // Converting back should fail because stepsOut=2 references non-existent grandparent - assertThrows( - IllegalArgumentException.class, - () -> { - protoExpressionConverter.from(protoExpression); - }, - "Should fail when stepsOut references non-existent grandparent lambda"); + void tripleNestedLambdaScopeTracking() { + lb.lambda( + List.of(R.I32), + outer -> + lb.lambda( + List.of(R.I64), + mid -> + lb.lambda( + List.of(R.STRING), + inner -> { + assertEquals(R.STRING, inner.ref(0).getType()); + assertEquals(R.I64, mid.ref(0).getType()); + assertEquals(R.I32, outer.ref(0).getType()); + + assertEquals( + 0, inner.ref(0).lambdaParameterReferenceStepsOut().orElse(-1)); + assertEquals(1, mid.ref(0).lambdaParameterReferenceStepsOut().orElse(-1)); + assertEquals( + 2, outer.ref(0).lambdaParameterReferenceStepsOut().orElse(-1)); + + return inner.ref(0); + }))); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java index 176b246e7..d5330681d 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java @@ -2,6 +2,7 @@ import io.substrait.expression.Expression; import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableExpression; import io.substrait.isthmus.CallConverter; import io.substrait.isthmus.SubstraitRelVisitor; import io.substrait.isthmus.TypeConverter; @@ -212,7 +213,7 @@ public Expression visitLambda(RexLambda rexLambda) { Expression body = rexLambda.getExpression().accept(this); - return Expression.Lambda.builder().parameters(parameters).body(body).build(); + return ImmutableExpression.Lambda.builder().parameters(parameters).body(body).build(); } @Override diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java index add746ee1..55924081d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java @@ -4,146 +4,57 @@ import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; -import io.substrait.expression.FieldReference; +import io.substrait.expression.LambdaBuilder; import io.substrait.relation.Project; import io.substrait.relation.Rel; -import io.substrait.type.Type; import java.util.ArrayList; import java.util.List; import org.junit.jupiter.api.Test; -/** - * Tests for Lambda expression conversion between Substrait and Calcite. Note: Calcite does not - * support nested lambda expressions for the moment, so all tests use stepsOut=0. - */ class LambdaExpressionTest extends PlanTestBase { final Rel emptyTable = sb.emptyVirtualTableScan(); + final LambdaBuilder lb = new LambdaBuilder(); - /** Test that lambdas with no parameters are valid. Building: () -> i32(42) : func<() -> i32> */ + // () -> 42 @Test void lambdaExpressionZeroParameters() { - Type.Struct params = Type.Struct.builder().nullable(false).build(); + List exprs = new ArrayList<>(); + exprs.add(lb.lambda(List.of(), params -> ExpressionCreator.i32(false, 42))); - Expression body = ExpressionCreator.i32(false, 42); - List expressionList = new ArrayList<>(); - Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); - expressionList.add(lambda); - - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); assertFullRoundTrip(project); } - /** - * Test valid field index with multiple parameters. Building: ($0: i32, $1: i64, $2: string) -> $2 - * : func<(i32, i64, string) -> string> - */ + // (x: i32, y: i64, z: string) -> x @Test void validFieldIndex() { - Type.Struct params = - Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - - List expressionList = new ArrayList<>(); - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(paramRef).build(); - - expressionList.add(lambda); - - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); - assertFullRoundTrip(project); - } - - /** - * Test deeply nested field ref inside Cast. Building: ($0: i32) -> cast($0 as i64) : func - * i64> - */ - @Test - void deeplyNestedFieldRef() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - - Expression.Cast castExpr = - (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - List expressionList = new ArrayList<>(); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(castExpr).build(); - - expressionList.add(lambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); - assertFullRoundTrip(project); - } - - /** - * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). Building: ($0: i32) -> cast(cast($0 - * as i64) as string) : func string> - */ - @Test - void doublyNestedFieldRef() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - Expression.Cast innerCast = - (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - Expression.Cast outerCast = - (Expression.Cast) - ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); + List exprs = new ArrayList<>(); + exprs.add(lb.lambda(List.of(R.I32, R.I64, R.STRING), params -> params.ref(0))); - List expressionList = new ArrayList<>(); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(outerCast).build(); - - expressionList.add(lambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); assertFullRoundTrip(project); } - /** - * Test lambda with literal body (no parameter references). Building: ($0: i32) -> 42 : func i32> - */ + // (x: i32) -> 42 @Test void lambdaWithLiteralBody() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - Expression body = ExpressionCreator.i32(false, 42); - List expressionList = new ArrayList<>(); + List exprs = new ArrayList<>(); + exprs.add(lb.lambda(List.of(R.I32), params -> ExpressionCreator.i32(false, 42))); - Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); - - expressionList.add(lambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); assertFullRoundTrip(project); } - /** - * Test that nested lambda (stepsOut > 0) throws UnsupportedOperationException. Calcite does not - * support nested lambda expressions. - */ + // (x: i64) -> (y: i32) -> x — Calcite doesn't support nested lambdas @Test void nestedLambdaThrowsUnsupportedOperation() { - Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); - - Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - // Inner lambda references outer's parameter with stepsOut=1 - FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); - - Expression.Lambda innerLambda = - Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); - - List expressionList = new ArrayList<>(); - Expression.Lambda outerLambda = - Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); + lb.lambda(List.of(R.I64), outer -> lb.lambda(List.of(R.I32), inner -> outer.ref(0))); - expressionList.add(outerLambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + List exprs = new ArrayList<>(); + exprs.add(outerLambda); + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); assertThrows(UnsupportedOperationException.class, () -> assertFullRoundTrip(project)); } } From 50db6998d84c3ee9af532e5ca2073e14e727287d Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 17:28:21 -0400 Subject: [PATCH 2/7] refactor: unify lambda validation, add JSON-based roundtrip tests Moves ProtoExpressionConverter to use LambdaBuilder for lambda parameter validation, removing the private LambdaParameterStack. Replaces builder-based roundtrip tests with parameterized JSON fixtures in expressions/lambda/. Adds arithmetic body test to isthmus LambdaExpressionTest. --- .../substrait/expression/LambdaBuilder.java | 39 +++ .../proto/ProtoExpressionConverter.java | 56 +--- .../expression/LambdaBuilderTest.java | 51 ++- .../proto/LambdaExpressionRoundtripTest.java | 313 ++---------------- .../lambda/invalid/nested_steps_out.json | 30 ++ .../expressions/lambda/invalid/steps_out.json | 21 ++ .../expressions/lambda/valid/identity.json | 21 ++ .../lambda/valid/literal_body.json | 14 + .../expressions/lambda/valid/multi_param.json | 23 ++ .../expressions/lambda/valid/nested.json | 30 ++ .../lambda/valid/triple_nested.json | 39 +++ .../expressions/lambda/valid/zero_params.json | 12 + .../isthmus/LambdaExpressionTest.java | 33 ++ 13 files changed, 348 insertions(+), 334 deletions(-) create mode 100644 core/src/test/resources/expressions/lambda/invalid/nested_steps_out.json create mode 100644 core/src/test/resources/expressions/lambda/invalid/steps_out.json create mode 100644 core/src/test/resources/expressions/lambda/valid/identity.json create mode 100644 core/src/test/resources/expressions/lambda/valid/literal_body.json create mode 100644 core/src/test/resources/expressions/lambda/valid/multi_param.json create mode 100644 core/src/test/resources/expressions/lambda/valid/nested.json create mode 100644 core/src/test/resources/expressions/lambda/valid/triple_nested.json create mode 100644 core/src/test/resources/expressions/lambda/valid/zero_params.json diff --git a/core/src/main/java/io/substrait/expression/LambdaBuilder.java b/core/src/main/java/io/substrait/expression/LambdaBuilder.java index b2b89c02d..1197cd1f1 100644 --- a/core/src/main/java/io/substrait/expression/LambdaBuilder.java +++ b/core/src/main/java/io/substrait/expression/LambdaBuilder.java @@ -54,6 +54,45 @@ public Expression.Lambda lambda(List paramTypes, Function bodyFn) { + pushLambdaContext(params); + try { + Expression body = bodyFn.get(); + return ImmutableExpression.Lambda.builder().parameters(params).body(body).build(); + } finally { + popLambdaContext(); + } + } + + /** + * Resolves the parameter struct for a lambda at the given stepsOut from the current innermost + * scope. Used by internal converters to validate lambda parameter references during + * deserialization. + * + * @param stepsOut number of lambda scopes to traverse outward (0 = current/innermost) + * @return the parameter struct at the target scope level + * @throws IllegalArgumentException if stepsOut exceeds the current nesting depth + */ + public Type.Struct resolveParams(int stepsOut) { + int index = lambdaContext.size() - 1 - stepsOut; + if (index < 0 || index >= lambdaContext.size()) { + throw new IllegalArgumentException( + String.format( + "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)", + stepsOut, lambdaContext.size())); + } + return lambdaContext.get(index); + } + /** * Pushes a lambda's parameters onto the context stack. This makes the parameters available for * validation when building the lambda's body, and allows nested lambda parameter references to diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 426b0f56a..470c13c6a 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -6,7 +6,7 @@ import io.substrait.expression.FieldReference.ReferenceSegment; import io.substrait.expression.FunctionArg; import io.substrait.expression.FunctionOption; -import io.substrait.expression.ImmutableExpression; +import io.substrait.expression.LambdaBuilder; import io.substrait.expression.WindowBound; import io.substrait.extension.ExtensionLookup; import io.substrait.extension.SimpleExtension; @@ -38,7 +38,7 @@ public class ProtoExpressionConverter { private final Type.Struct rootType; private final ProtoTypeConverter protoTypeConverter; private final ProtoRelConverter protoRelConverter; - private final LambdaParameterStack lambdaParameterStack = new LambdaParameterStack(); + private final LambdaBuilder lambdaBuilder = new LambdaBuilder(); public ProtoExpressionConverter( ExtensionLookup lookup, @@ -83,7 +83,7 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc reference.getLambdaParameterReference(); int stepsOut = lambdaParamRef.getStepsOut(); - Type.Struct lambdaParameters = lambdaParameterStack.get(stepsOut); + Type.Struct lambdaParameters = lambdaBuilder.resolveParams(stepsOut); // Check for unsupported nested field access if (reference.getDirectReference().getStructField().hasChild()) { @@ -283,7 +283,6 @@ public Type visit(Type.Struct type) throws RuntimeException { case LAMBDA: { - // TODO: Add build-time validation of lambda parameter references during deserialization. io.substrait.proto.Expression.Lambda protoLambda = expr.getLambda(); Type.Struct parameters = (Type.Struct) @@ -292,16 +291,7 @@ public Type visit(Type.Struct type) throws RuntimeException { .setStruct(protoLambda.getParameters()) .build()); - lambdaParameterStack.push(parameters); - - Expression body; - try { - body = from(protoLambda.getBody()); - } finally { - lambdaParameterStack.pop(); - } - - return ImmutableExpression.Lambda.builder().parameters(parameters).body(body).build(); + return lambdaBuilder.lambdaFromStruct(parameters, () -> from(protoLambda.getBody())); } // TODO enum. case ENUM: @@ -622,42 +612,4 @@ public Expression.SortField fromSortField(SortField s) { public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) { return FunctionOption.builder().name(o.getName()).addAllValues(o.getPreferenceList()).build(); } - - /** - * A stack for tracking lambda parameter types during expression parsing. - * - *

When parsing nested lambda expressions, each lambda's parameters are pushed onto this stack. - * Lambda parameter references use "stepsOut" to indicate which enclosing lambda they reference: - * - *

    - *
  • stepsOut=0 refers to the innermost (current) lambda - *
  • stepsOut=1 refers to the next enclosing lambda - *
  • stepsOut=N refers to N levels up - *
- */ - private static class LambdaParameterStack { - private final List stack = new ArrayList<>(); - - void push(Type.Struct parameters) { - stack.add(parameters); - } - - void pop() { - if (stack.isEmpty()) { - throw new IllegalArgumentException("Lambda parameter stack is empty"); - } - stack.remove(stack.size() - 1); - } - - Type.Struct get(int stepsOut) { - int index = stack.size() - 1 - stepsOut; - if (index < 0 || index >= stack.size()) { - throw new IllegalArgumentException( - String.format( - "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)", - stepsOut, stack.size())); - } - return stack.get(index); - } - } } diff --git a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java index 7d8f8976b..492c56f30 100644 --- a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java +++ b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java @@ -1,32 +1,73 @@ package io.substrait.expression; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.util.List; import org.junit.jupiter.api.Test; -/** Tests for {@link LambdaBuilder} build-time validation. */ +/** Tests for {@link LambdaBuilder}. */ class LambdaBuilderTest { static final TypeCreator R = TypeCreator.REQUIRED; final LambdaBuilder lb = new LambdaBuilder(); - // (x: i32) -> x[5] — field index 5 is out of bounds (only 1 param) + // (x: i32)@p -> p[0] + @Test + void simpleLambda() { + Expression.Lambda lambda = lb.lambda(List.of(R.I32), params -> params.ref(0)); + + Expression.Lambda expected = + ImmutableExpression.Lambda.builder() + .parameters(Type.Struct.builder().nullable(false).addFields(R.I32).build()) + .body( + FieldReference.newLambdaParameterReference( + 0, Type.Struct.builder().nullable(false).addFields(R.I32).build(), 0)) + .build(); + + assertEquals(expected, lambda); + } + + // (x: i32)@outer -> (y: i64)@inner -> outer[0] + @Test + void nestedLambda() { + Expression.Lambda lambda = + lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> outer.ref(0))); + + Expression.Lambda expectedInner = + ImmutableExpression.Lambda.builder() + .parameters(Type.Struct.builder().nullable(false).addFields(R.I64).build()) + .body( + FieldReference.newLambdaParameterReference( + 0, Type.Struct.builder().nullable(false).addFields(R.I32).build(), 1)) + .build(); + + Expression.Lambda expected = + ImmutableExpression.Lambda.builder() + .parameters(Type.Struct.builder().nullable(false).addFields(R.I32).build()) + .body(expectedInner) + .build(); + + assertEquals(expected, lambda); + } + + // (x: i32)@p -> p[5] — only 1 param, index 5 is out of bounds @Test void invalidFieldIndex_outOfBounds() { assertThrows( IndexOutOfBoundsException.class, () -> lb.lambda(List.of(R.I32), params -> params.ref(5))); } - // (x: i32) -> x[-1] — negative field index + // (x: i32)@p -> p[-1] — negative index @Test void negativeFieldIndex() { assertThrows(Exception.class, () -> lb.lambda(List.of(R.I32), params -> params.ref(-1))); } - // (x: i32) -> (y: i64) -> x[5] — outer field index 5 is out of bounds + // (x: i32)@outer -> (y: i64)@inner -> outer[5] — outer only has 1 param @Test void nestedOuterFieldIndexOutOfBounds() { assertThrows( @@ -34,7 +75,7 @@ void nestedOuterFieldIndexOutOfBounds() { () -> lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> outer.ref(5)))); } - // (x: i32) -> (y: i64) -> y[3] — inner field index 3 is out of bounds (only 1 param) + // (x: i32)@outer -> (y: i64)@inner -> inner[3] — inner only has 1 param @Test void nestedInnerFieldIndexOutOfBounds() { assertThrows( diff --git a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java index 3b9cfedad..1589c35ba 100644 --- a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java @@ -1,298 +1,57 @@ package io.substrait.type.proto; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; +import com.google.protobuf.util.JsonFormat; import io.substrait.TestBase; import io.substrait.expression.Expression; -import io.substrait.expression.ExpressionCreator; -import io.substrait.expression.FieldReference; -import io.substrait.expression.LambdaBuilder; -import io.substrait.extension.DefaultExtensionCatalog; -import io.substrait.type.Type; -import java.util.List; -import org.junit.jupiter.api.Test; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; class LambdaExpressionRoundtripTest extends TestBase { - final LambdaBuilder lb = new LambdaBuilder(); - - // ==================== Single Lambda Tests ==================== - - // () -> 42 - @Test - void zeroParameterLambda() { - Expression.Lambda lambda = lb.lambda(List.of(), params -> ExpressionCreator.i32(false, 42)); - - verifyRoundTrip(lambda); - - Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals(0, funcType.parameterTypes().size()); - assertEquals(R.I32, funcType.returnType()); - } - - // (x: i32) -> x - @Test - void identityLambda() { - Expression.Lambda lambda = lb.lambda(List.of(R.I32), params -> params.ref(0)); - - verifyRoundTrip(lambda); - - Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals(1, funcType.parameterTypes().size()); - assertEquals(R.I32, funcType.parameterTypes().get(0)); - assertEquals(R.I32, funcType.returnType()); - - assertInstanceOf(FieldReference.class, lambda.body()); - FieldReference ref = (FieldReference) lambda.body(); - assertTrue(ref.isLambdaParameterReference()); - assertEquals(0, ref.lambdaParameterReferenceStepsOut().orElse(-1)); - } - - // (x: i32, y: i64, z: string) -> z - @Test - void validFieldIndex() { - Expression.Lambda lambda = lb.lambda(List.of(R.I32, R.I64, R.STRING), params -> params.ref(2)); - - verifyRoundTrip(lambda); - assertEquals(R.STRING, ((Type.Func) lambda.getType()).returnType()); - } - - // (x: i32) -> 42 - @Test - void lambdaWithLiteralBody() { - Expression.Lambda lambda = - lb.lambda(List.of(R.I32), params -> ExpressionCreator.i32(false, 42)); - - verifyRoundTrip(lambda); - assertInstanceOf(Expression.I32Literal.class, lambda.body()); - } - - // Parameterized: (params...) -> params[fieldIndex], verifying type resolution - @Test - void typeResolution() { - record TestCase(String name, List paramTypes, int fieldIndex, Type expectedType) {} - - List testCases = - List.of( - new TestCase("first param (i32)", List.of(R.I32), 0, R.I32), - new TestCase("second param (i64)", List.of(R.I32, R.I64), 1, R.I64), - new TestCase("third param (string)", List.of(R.I32, R.I64, R.STRING), 2, R.STRING), - new TestCase("float64 param", List.of(R.FP64), 0, R.FP64), - new TestCase("date param", List.of(R.BOOLEAN, R.DATE, R.TIMESTAMP), 1, R.DATE)); - - for (TestCase tc : testCases) { - Expression.Lambda lambda = lb.lambda(tc.paramTypes, params -> params.ref(tc.fieldIndex)); - - verifyRoundTrip(lambda); - - assertEquals(tc.expectedType, lambda.body().getType(), tc.name + ": body type mismatch"); - Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals( - tc.expectedType, funcType.returnType(), tc.name + ": lambda return type mismatch"); - } - } - - // (x: i32, y: string) -> y — verify full Func type structure - @Test - void lambdaGetTypeReturnsFunc() { - Expression.Lambda lambda = lb.lambda(List.of(R.I32, R.STRING), params -> params.ref(1)); - - Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals(2, funcType.parameterTypes().size()); - assertEquals(R.I32, funcType.parameterTypes().get(0)); - assertEquals(R.STRING, funcType.parameterTypes().get(1)); - assertEquals(R.STRING, funcType.returnType()); - } - - // (x: i32, y: i64, z: string) -> ... — verify FieldReference metadata for each param - @Test - void parameterReferenceMetadata() { - List paramTypes = List.of(R.I32, R.I64, R.STRING); - - lb.lambda( - paramTypes, - params -> { - for (int i = 0; i < 3; i++) { - FieldReference ref = params.ref(i); - assertTrue(ref.isLambdaParameterReference()); - assertFalse(ref.isOuterReference()); - assertFalse(ref.isSimpleRootReference()); - assertEquals(0, ref.lambdaParameterReferenceStepsOut().orElse(-1)); - assertEquals(paramTypes.get(i), ref.getType()); - } - return params.ref(0); - }); - } - - // ==================== Expression Body Tests ==================== - - // (x: i64) -> (y1: i64, y2: i64) -> y1 * x + y2 - @Test - void nestedLambdaWithArithmeticBody() { - String ARITH = DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC; - - Expression.Lambda result = - lb.lambda( - List.of(R.I64), - outer -> - lb.lambda( - List.of(R.I64, R.I64), - inner -> { - // y1 * x - Expression multiply = - sb.scalarFn(ARITH, "multiply:i64_i64", R.I64, inner.ref(0), outer.ref(0)); - // (y1 * x) + y2 - return sb.scalarFn(ARITH, "add:i64_i64", R.I64, multiply, inner.ref(1)); - })); - - verifyRoundTrip(result); - - // Outer lambda returns a func type - Type.Func outerFuncType = (Type.Func) result.getType(); - assertInstanceOf(Type.Func.class, outerFuncType.returnType()); - - // Inner lambda returns i64 - Expression.Lambda innerLambda = (Expression.Lambda) result.body(); - Type.Func innerFuncType = (Type.Func) innerLambda.getType(); - assertEquals(R.I64, innerFuncType.returnType()); - - // Inner body is a scalar function (add) - assertInstanceOf(Expression.ScalarFunctionInvocation.class, innerLambda.body()); - } - - // ==================== Nested Lambda Tests ==================== - - // (x: i64, y: i64) -> (z: i32) -> x - @Test - void nestedLambdaWithOuterRef() { - Expression.Lambda result = - lb.lambda(List.of(R.I64, R.I64), outer -> lb.lambda(List.of(R.I32), inner -> outer.ref(0))); - - verifyRoundTrip(result); - - Expression.Lambda resultInner = (Expression.Lambda) result.body(); - assertEquals(1, resultInner.parameters().fields().size()); - assertEquals(R.I64, resultInner.body().getType()); - } - - // (x: i32, y: i64, z: string) -> (w: fp64) -> z - @Test - void nestedLambdaOuterRefTypeResolution() { - Expression.Lambda result = - lb.lambda( - List.of(R.I32, R.I64, R.STRING), - outer -> lb.lambda(List.of(R.FP64), inner -> outer.ref(2))); - - verifyRoundTrip(result); - - Expression.Lambda resultInner = (Expression.Lambda) result.body(); - assertEquals(R.STRING, ((Type.Func) resultInner.getType()).returnType()); + static Stream validLambdaExpressions() throws IOException { + return listJsonResources("expressions/lambda/valid"); } - // (x: i32) -> (y: i64) -> y - @Test - void nestedLambdaInnerRefOnly() { - Expression.Lambda result = - lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> inner.ref(0))); - - verifyRoundTrip(result); - - Expression.Lambda innerLambda = (Expression.Lambda) result.body(); - assertEquals(R.I64, innerLambda.body().getType()); - assertInstanceOf(Type.Func.class, ((Type.Func) result.getType()).returnType()); + static Stream invalidLambdaExpressions() throws IOException { + return listJsonResources("expressions/lambda/invalid"); } - // (x: i32) -> (y: i64) -> (x, y) — body references both outer and inner params - @Test - void nestedLambdaBothInnerAndOuterRefs() { - Expression.Lambda result = - lb.lambda( - List.of(R.I32), - outer -> - lb.lambda( - List.of(R.I64), - inner -> { - FieldReference innerRef = inner.ref(0); - assertEquals(R.I64, innerRef.getType()); - assertEquals(0, innerRef.lambdaParameterReferenceStepsOut().orElse(-1)); - - FieldReference outerRef = outer.ref(0); - assertEquals(R.I32, outerRef.getType()); - assertEquals(1, outerRef.lambdaParameterReferenceStepsOut().orElse(-1)); - - return innerRef; - })); - - verifyRoundTrip(result); + @ParameterizedTest + @MethodSource("validLambdaExpressions") + void validLambdaExpressionRoundtrip(String resourcePath) throws IOException { + Expression deserialized = deserializeExpression(resourcePath); + assertInstanceOf(Expression.Lambda.class, deserialized); + verifyRoundTrip(deserialized); } - // (a: i32, b: string) -> (c: i64, d: fp64) -> b — verify all 4 params resolve correctly - @Test - void nestedLambdaMultiParamCorrectResolution() { - Expression.Lambda result = - lb.lambda( - List.of(R.I32, R.STRING), - outer -> - lb.lambda( - List.of(R.I64, R.FP64), - inner -> { - assertEquals(R.I64, inner.ref(0).getType()); - assertEquals(R.FP64, inner.ref(1).getType()); - assertEquals(R.I32, outer.ref(0).getType()); - assertEquals(R.STRING, outer.ref(1).getType()); - - return outer.ref(1); - })); - - verifyRoundTrip(result); - - Expression.Lambda innerLambda = (Expression.Lambda) result.body(); - assertEquals(R.STRING, ((Type.Func) innerLambda.getType()).returnType()); + @ParameterizedTest + @MethodSource("invalidLambdaExpressions") + void invalidLambdaExpressionRejected(String resourcePath) { + assertThrows(Exception.class, () -> deserializeExpression(resourcePath)); } - // (x: i32) -> (y: i64) -> (z: string) -> x - @Test - void tripleNestedLambdaRoundtrip() { - Expression.Lambda result = - lb.lambda( - List.of(R.I32), - outer -> - lb.lambda( - List.of(R.I64), mid -> lb.lambda(List.of(R.STRING), inner -> outer.ref(0)))); - - verifyRoundTrip(result); - - Expression.Lambda l1 = (Expression.Lambda) result.body(); - Expression.Lambda l2 = (Expression.Lambda) l1.body(); - assertEquals(R.I32, l2.body().getType()); + private static Stream listJsonResources(String dirPath) throws IOException { + Path dir = + Paths.get( + LambdaExpressionRoundtripTest.class.getClassLoader().getResource(dirPath).getPath()); + return Files.list(dir) + .filter(p -> p.toString().endsWith(".json")) + .map(p -> dirPath + "/" + p.getFileName().toString()) + .sorted(); } - // (x: i32) -> (y: i64) -> (z: string) -> ... — verify stepsOut is auto-computed at each level - @Test - void tripleNestedLambdaScopeTracking() { - lb.lambda( - List.of(R.I32), - outer -> - lb.lambda( - List.of(R.I64), - mid -> - lb.lambda( - List.of(R.STRING), - inner -> { - assertEquals(R.STRING, inner.ref(0).getType()); - assertEquals(R.I64, mid.ref(0).getType()); - assertEquals(R.I32, outer.ref(0).getType()); - - assertEquals( - 0, inner.ref(0).lambdaParameterReferenceStepsOut().orElse(-1)); - assertEquals(1, mid.ref(0).lambdaParameterReferenceStepsOut().orElse(-1)); - assertEquals( - 2, outer.ref(0).lambdaParameterReferenceStepsOut().orElse(-1)); - - return inner.ref(0); - }))); + private Expression deserializeExpression(String resourcePath) throws IOException { + String json = asString(resourcePath); + io.substrait.proto.Expression.Builder builder = io.substrait.proto.Expression.newBuilder(); + JsonFormat.parser().merge(json, builder); + return protoExpressionConverter.from(builder.build()); } } diff --git a/core/src/test/resources/expressions/lambda/invalid/nested_steps_out.json b/core/src/test/resources/expressions/lambda/invalid/nested_steps_out.json new file mode 100644 index 000000000..54ec7321b --- /dev/null +++ b/core/src/test/resources/expressions/lambda/invalid/nested_steps_out.json @@ -0,0 +1,30 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "lambda": { + "parameters": { + "types": [ + { "i64": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 2 + } + } + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/invalid/steps_out.json b/core/src/test/resources/expressions/lambda/invalid/steps_out.json new file mode 100644 index 000000000..148173e8b --- /dev/null +++ b/core/src/test/resources/expressions/lambda/invalid/steps_out.json @@ -0,0 +1,21 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 1 + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/identity.json b/core/src/test/resources/expressions/lambda/valid/identity.json new file mode 100644 index 000000000..a984334ef --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/identity.json @@ -0,0 +1,21 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 0 + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/literal_body.json b/core/src/test/resources/expressions/lambda/valid/literal_body.json new file mode 100644 index 000000000..c36e2df5e --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/literal_body.json @@ -0,0 +1,14 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "literal": { + "i32": 42 + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/multi_param.json b/core/src/test/resources/expressions/lambda/valid/multi_param.json new file mode 100644 index 000000000..1c31bc1e7 --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/multi_param.json @@ -0,0 +1,23 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } }, + { "i64": { "nullability": "NULLABILITY_REQUIRED" } }, + { "string": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "lambdaParameterReference": { + "stepsOut": 0 + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/nested.json b/core/src/test/resources/expressions/lambda/valid/nested.json new file mode 100644 index 000000000..abf716b7a --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/nested.json @@ -0,0 +1,30 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "lambda": { + "parameters": { + "types": [ + { "i64": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 1 + } + } + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/triple_nested.json b/core/src/test/resources/expressions/lambda/valid/triple_nested.json new file mode 100644 index 000000000..e1e692521 --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/triple_nested.json @@ -0,0 +1,39 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "lambda": { + "parameters": { + "types": [ + { "i64": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "lambda": { + "parameters": { + "types": [ + { "string": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 2 + } + } + } + } + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/zero_params.json b/core/src/test/resources/expressions/lambda/valid/zero_params.json new file mode 100644 index 000000000..df801c9cd --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/zero_params.json @@ -0,0 +1,12 @@ +{ + "lambda": { + "parameters": { + "types": [] + }, + "body": { + "literal": { + "i32": 42 + } + } + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java index 55924081d..0c294ff9f 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java @@ -1,10 +1,12 @@ package io.substrait.isthmus; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.LambdaBuilder; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.relation.Project; import io.substrait.relation.Rel; import java.util.ArrayList; @@ -57,4 +59,35 @@ void nestedLambdaThrowsUnsupportedOperation() { Project project = Project.builder().expressions(exprs).input(emptyTable).build(); assertThrows(UnsupportedOperationException.class, () -> assertFullRoundTrip(project)); } + + // (x: i64) -> (y1: i64, y2: i64) -> y1 * x + y2 + @Test + void nestedLambdaWithArithmeticBody() { + String ARITH = DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC; + + Expression.Lambda lambda = + lb.lambda( + List.of(R.I64), + outer -> + lb.lambda( + List.of(R.I64, R.I64), + inner -> { + Expression multiply = + sb.scalarFn(ARITH, "multiply:i64_i64", R.I64, inner.ref(0), outer.ref(0)); + return sb.scalarFn(ARITH, "add:i64_i64", R.I64, multiply, inner.ref(1)); + })); + + // Proto-only roundtrip since Calcite doesn't support nested lambdas + List exprs = new ArrayList<>(); + exprs.add(lambda); + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); + + io.substrait.extension.ExtensionCollector collector = + new io.substrait.extension.ExtensionCollector(); + io.substrait.proto.Rel proto = + new io.substrait.relation.RelProtoConverter(collector).toProto(project); + io.substrait.relation.Rel roundTripped = + new io.substrait.relation.ProtoRelConverter(collector, extensions).from(proto); + assertEquals(project, roundTripped); + } } From f172704dce0b71de4f971c55d3f7a653a6d66b5d Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 17:35:59 -0400 Subject: [PATCH 3/7] docs: fix LambdaBuilder javadoc to use params/outer/inner naming --- .../main/java/io/substrait/expression/LambdaBuilder.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/LambdaBuilder.java b/core/src/main/java/io/substrait/expression/LambdaBuilder.java index 1197cd1f1..621a5ef90 100644 --- a/core/src/main/java/io/substrait/expression/LambdaBuilder.java +++ b/core/src/main/java/io/substrait/expression/LambdaBuilder.java @@ -19,12 +19,12 @@ * LambdaBuilder lb = new LambdaBuilder(); * * // Simple: (x: i32) -> x - * Expression.Lambda simple = lb.lambda(List.of(R.I32), x -> x.ref(0)); + * Expression.Lambda simple = lb.lambda(List.of(R.I32), params -> params.ref(0)); * * // Nested: (x: i32) -> (y: i64) -> add(x, y) - * Expression.Lambda nested = lb.lambda(List.of(R.I32), x -> - * lb.lambda(List.of(R.I64), y -> - * add(x.ref(0), y.ref(0)) + * Expression.Lambda nested = lb.lambda(List.of(R.I32), outer -> + * lb.lambda(List.of(R.I64), inner -> + * add(outer.ref(0), inner.ref(0)) * ) * ); * } From 67c5a8a763af262653cfd72a662b3e9841ab3339 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 17:45:12 -0400 Subject: [PATCH 4/7] refactor: clarify Scope internals, extract stepsOut() method and document depth-capture mechanism --- .../substrait/expression/LambdaBuilder.java | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/LambdaBuilder.java b/core/src/main/java/io/substrait/expression/LambdaBuilder.java index 621a5ef90..8709b6aa6 100644 --- a/core/src/main/java/io/substrait/expression/LambdaBuilder.java +++ b/core/src/main/java/io/substrait/expression/LambdaBuilder.java @@ -45,8 +45,7 @@ public Expression.Lambda lambda(List paramTypes, Function= lambdaContext.size()) { + int targetDepth = lambdaContext.size() - stepsOut; + if (targetDepth <= 0 || targetDepth > lambdaContext.size()) { throw new IllegalArgumentException( String.format( "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)", stepsOut, lambdaContext.size())); } - return lambdaContext.get(index); + return lambdaContext.get(targetDepth - 1); } /** @@ -113,26 +112,39 @@ private void popLambdaContext() { /** * A handle to a particular lambda's parameter scope. Use {@link #ref} to create validated * parameter references. + * + *

Each Scope captures the depth of the lambdaContext stack at the time it was created. When + * {@link #ref} is called, the Substrait {@code stepsOut} value is computed as the difference + * between the current stack depth and the captured depth. This means the same Scope produces + * different stepsOut values depending on the nesting level at the time of the call, which is what + * allows outer.ref(0) to produce stepsOut=1 when called inside a nested lambda. */ public class Scope { - private final int index; + private final Type.Struct params; + private final int depth; + + private Scope(Type.Struct params) { + this.params = params; + this.depth = lambdaContext.size(); + } - private Scope(int index) { - this.index = index; + /** + * Computes the number of lambda boundaries between this scope and the current innermost scope. + * This value changes dynamically as nested lambdas are built. + */ + private int stepsOut() { + return lambdaContext.size() - depth; } /** - * Creates a validated reference to a parameter of this lambda. The correct {@code stepsOut} - * value is computed automatically. + * Creates a validated reference to a parameter of this lambda. * * @param paramIndex index of the parameter within this lambda's parameter struct * @return a {@link FieldReference} pointing to the specified parameter * @throws IndexOutOfBoundsException if paramIndex is out of bounds */ public FieldReference ref(int paramIndex) { - int stepsOut = lambdaContext.size() - 1 - index; - return FieldReference.newLambdaParameterReference( - paramIndex, lambdaContext.get(index), stepsOut); + return FieldReference.newLambdaParameterReference(paramIndex, params, stepsOut()); } } } From eed9ea92cf79243a54f14486ea398f4e0a9daf6c Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 17:46:49 -0400 Subject: [PATCH 5/7] test: add test verifying stepsOut changes dynamically with nesting depth --- .../expression/LambdaBuilderTest.java | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java index 492c56f30..a4c569219 100644 --- a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java +++ b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java @@ -54,6 +54,31 @@ void nestedLambda() { assertEquals(expected, lambda); } + // Verify that the same scope handle produces different stepsOut values depending on nesting. + // outer.ref(0) should produce stepsOut=0 at the top level and stepsOut=1 inside a nested lambda. + @Test + void scopeStepsOutChangesDynamically() { + Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); + + lb.lambda( + List.of(R.I32), + outer -> { + FieldReference atTopLevel = outer.ref(0); + assertEquals(0, atTopLevel.lambdaParameterReferenceStepsOut().orElse(-1)); + + lb.lambda( + List.of(R.I64), + inner -> { + FieldReference atNestedLevel = outer.ref(0); + assertEquals(1, atNestedLevel.lambdaParameterReferenceStepsOut().orElse(-1)); + return inner.ref(0); + }); + + return atTopLevel; + }); + } + // (x: i32)@p -> p[5] — only 1 param, index 5 is out of bounds @Test void invalidFieldIndex_outOfBounds() { From c37d527a2fea7eedbdd00af258fa78f7ef64f9dc Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 17:54:32 -0400 Subject: [PATCH 6/7] test: simplify arithmetic body test to single lambda (x -> x + x) --- .../isthmus/LambdaExpressionTest.java | 24 ++++--------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java index 0c294ff9f..fb33407b8 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java @@ -1,6 +1,5 @@ package io.substrait.isthmus; -import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import io.substrait.expression.Expression; @@ -60,34 +59,19 @@ void nestedLambdaThrowsUnsupportedOperation() { assertThrows(UnsupportedOperationException.class, () -> assertFullRoundTrip(project)); } - // (x: i64) -> (y1: i64, y2: i64) -> y1 * x + y2 + // (x: i64)@p -> add(p[0], p[0]) @Test - void nestedLambdaWithArithmeticBody() { + void lambdaWithArithmeticBody() { String ARITH = DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC; Expression.Lambda lambda = lb.lambda( List.of(R.I64), - outer -> - lb.lambda( - List.of(R.I64, R.I64), - inner -> { - Expression multiply = - sb.scalarFn(ARITH, "multiply:i64_i64", R.I64, inner.ref(0), outer.ref(0)); - return sb.scalarFn(ARITH, "add:i64_i64", R.I64, multiply, inner.ref(1)); - })); + params -> sb.scalarFn(ARITH, "add:i64_i64", R.I64, params.ref(0), params.ref(0))); - // Proto-only roundtrip since Calcite doesn't support nested lambdas List exprs = new ArrayList<>(); exprs.add(lambda); Project project = Project.builder().expressions(exprs).input(emptyTable).build(); - - io.substrait.extension.ExtensionCollector collector = - new io.substrait.extension.ExtensionCollector(); - io.substrait.proto.Rel proto = - new io.substrait.relation.RelProtoConverter(collector).toProto(project); - io.substrait.relation.Rel roundTripped = - new io.substrait.relation.ProtoRelConverter(collector, extensions).from(proto); - assertEquals(project, roundTripped); + assertFullRoundTrip(project); } } From d55762905a0a53252b98f5ae506ab974796cdd7c Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 18:05:00 -0400 Subject: [PATCH 7/7] fix: remove unused local variables flagged by PMD --- .../test/java/io/substrait/expression/LambdaBuilderTest.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java index a4c569219..01303d932 100644 --- a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java +++ b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java @@ -58,9 +58,6 @@ void nestedLambda() { // outer.ref(0) should produce stepsOut=0 at the top level and stepsOut=1 inside a nested lambda. @Test void scopeStepsOutChangesDynamically() { - Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); - lb.lambda( List.of(R.I32), outer -> {