Skip to content
4 changes: 0 additions & 4 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -775,10 +775,6 @@ public Type getType() {
return Type.withNullability(false).func(paramTypes, returnType);
}

/**
 * Creates a new builder by delegating to {@code ImmutableExpression.Lambda.builder()}.
 *
 * @return a fresh {@code ImmutableExpression.Lambda.Builder}
 */
public static ImmutableExpression.Lambda.Builder builder() {
return ImmutableExpression.Lambda.builder();
}

@Override
public <R, C extends VisitationContext, E extends Throwable> R accept(
ExpressionVisitor<R, C, E> visitor, C context) throws E {
Expand Down
150 changes: 150 additions & 0 deletions core/src/main/java/io/substrait/expression/LambdaBuilder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package io.substrait.expression;

import io.substrait.type.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/**
 * Builds lambda expressions with build-time validation of parameter references.
 *
 * <p>Maintains a stack of lambda parameter scopes. Each call to {@link #lambda} pushes parameters
 * onto the stack, builds the body via a callback, and pops. Nested lambdas simply call {@code
 * lambda()} again on the same builder.
 *
 * <p>The callback receives a {@link Scope} handle for creating validated parameter references. The
 * correct {@code stepsOut} value is computed automatically from the stack. A {@link Scope} is only
 * usable while the body of the lambda it belongs to is being built; calling {@link Scope#ref} on a
 * handle that has escaped its callback fails with an {@link IllegalStateException}.
 *
 * <p>The scope stack is a plain, unsynchronized list that is mutated while a body is being built,
 * so a single builder instance must not be used from several threads at once.
 *
 * <pre>{@code
 * LambdaBuilder lb = new LambdaBuilder();
 *
 * // Simple: (x: i32) -> x
 * Expression.Lambda simple = lb.lambda(List.of(R.I32), params -> params.ref(0));
 *
 * // Nested: (x: i32) -> (y: i64) -> add(x, y)
 * Expression.Lambda nested = lb.lambda(List.of(R.I32), outer ->
 *     lb.lambda(List.of(R.I64), inner ->
 *         add(outer.ref(0), inner.ref(0))
 *     )
 * );
 * }</pre>
 */
public class LambdaBuilder {
  /** Parameter structs of the lambdas currently being built; the innermost lambda is last. */
  private final List<Type.Struct> lambdaContext = new ArrayList<>();

  /**
   * Builds a lambda expression. The body function receives a {@link Scope} for creating validated
   * parameter references. Nested lambdas are built by calling this method again inside the
   * callback.
   *
   * @param paramTypes the lambda's parameter types
   * @param bodyFn function that builds the lambda body given a scope handle
   * @return the constructed lambda expression
   */
  public Expression.Lambda lambda(List<Type> paramTypes, Function<Scope, Expression> bodyFn) {
    Type.Struct params = Type.Struct.builder().nullable(false).addAllFields(paramTypes).build();
    // The Scope is created inside the supplier, i.e. after the parameters were pushed, so that it
    // captures the stack depth of its own lambda.
    return buildWithinScope(params, () -> bodyFn.apply(new Scope(params)));
  }

  /**
   * Builds a lambda expression from a pre-built parameter struct. Used by internal converters that
   * already have a Type.Struct (e.g., during protobuf deserialization).
   *
   * @param params the lambda's parameter struct
   * @param bodyFn function that builds the lambda body
   * @return the constructed lambda expression
   */
  public Expression.Lambda lambdaFromStruct(
      Type.Struct params, java.util.function.Supplier<Expression> bodyFn) {
    return buildWithinScope(params, bodyFn);
  }

  /**
   * Resolves the parameter struct for a lambda at the given stepsOut from the current innermost
   * scope. Used by internal converters to validate lambda parameter references during
   * deserialization.
   *
   * @param stepsOut number of lambda scopes to traverse outward (0 = current/innermost)
   * @return the parameter struct at the target scope level
   * @throws IllegalArgumentException if stepsOut is negative or exceeds the current nesting depth
   */
  public Type.Struct resolveParams(int stepsOut) {
    int depth = lambdaContext.size();
    if (stepsOut < 0 || stepsOut >= depth) {
      throw new IllegalArgumentException(
          String.format(
              "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)",
              stepsOut, depth));
    }
    return lambdaContext.get(depth - 1 - stepsOut);
  }

  /**
   * Pushes {@code params} as the innermost scope, builds the body, and assembles the lambda. The
   * scope is popped again even when building the body throws, so the builder stays usable.
   */
  private Expression.Lambda buildWithinScope(
      Type.Struct params, java.util.function.Supplier<Expression> bodyFn) {
    lambdaContext.add(params);
    try {
      Expression body = bodyFn.get();
      return ImmutableExpression.Lambda.builder().parameters(params).body(body).build();
    } finally {
      lambdaContext.remove(lambdaContext.size() - 1);
    }
  }

  /**
   * A handle to a particular lambda's parameter scope. Use {@link #ref} to create validated
   * parameter references.
   *
   * <p>Each Scope captures the depth of the lambdaContext stack at the time it was created. When
   * {@link #ref} is called, the Substrait {@code stepsOut} value is computed as the difference
   * between the current stack depth and the captured depth. This means the same Scope produces
   * different stepsOut values depending on the nesting level at the time of the call, which is what
   * allows outer.ref(0) to produce stepsOut=1 when called inside a nested lambda.
   */
  public class Scope {
    private final Type.Struct params;
    private final int depth;

    private Scope(Type.Struct params) {
      this.params = params;
      this.depth = lambdaContext.size();
    }

    /**
     * Computes the number of lambda boundaries between this scope and the current innermost scope.
     * This value changes dynamically as nested lambdas are built.
     *
     * @throws IllegalStateException if this scope's lambda is no longer on the context stack
     */
    private int stepsOut() {
      int currentDepth = lambdaContext.size();
      // Reference comparison is intentional: the stack slot at this scope's depth must hold the
      // very struct instance that was pushed for this scope. Anything else means the handle
      // escaped its callback and would otherwise yield a negative or misattributed stepsOut.
      boolean active =
          depth >= 1 && depth <= currentDepth && lambdaContext.get(depth - 1) == params;
      if (!active) {
        throw new IllegalStateException(
            "Lambda scope is no longer active; ref() must be called while the body of its lambda"
                + " is being built");
      }
      return currentDepth - depth;
    }

    /**
     * Creates a validated reference to a parameter of this lambda.
     *
     * @param paramIndex index of the parameter within this lambda's parameter struct
     * @return a {@link FieldReference} pointing to the specified parameter
     * @throws IndexOutOfBoundsException if paramIndex is out of bounds
     * @throws IllegalStateException if called after the body of this scope's lambda was built
     */
    public FieldReference ref(int paramIndex) {
      return FieldReference.newLambdaParameterReference(paramIndex, params, stepsOut());
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.substrait.expression.FieldReference.ReferenceSegment;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.FunctionOption;
import io.substrait.expression.LambdaBuilder;
import io.substrait.expression.WindowBound;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.SimpleExtension;
Expand Down Expand Up @@ -37,7 +38,7 @@ public class ProtoExpressionConverter {
private final Type.Struct rootType;
private final ProtoTypeConverter protoTypeConverter;
private final ProtoRelConverter protoRelConverter;
private final LambdaParameterStack lambdaParameterStack = new LambdaParameterStack();
private final LambdaBuilder lambdaBuilder = new LambdaBuilder();

public ProtoExpressionConverter(
ExtensionLookup lookup,
Expand Down Expand Up @@ -82,7 +83,7 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc
reference.getLambdaParameterReference();

int stepsOut = lambdaParamRef.getStepsOut();
Type.Struct lambdaParameters = lambdaParameterStack.get(stepsOut);
Type.Struct lambdaParameters = lambdaBuilder.resolveParams(stepsOut);

// Check for unsupported nested field access
if (reference.getDirectReference().getStructField().hasChild()) {
Expand Down Expand Up @@ -290,16 +291,7 @@ public Type visit(Type.Struct type) throws RuntimeException {
.setStruct(protoLambda.getParameters())
.build());

lambdaParameterStack.push(parameters);

Expression body;
try {
body = from(protoLambda.getBody());
} finally {
lambdaParameterStack.pop();
}

return Expression.Lambda.builder().parameters(parameters).body(body).build();
return lambdaBuilder.lambdaFromStruct(parameters, () -> from(protoLambda.getBody()));
}
// TODO enum.
case ENUM:
Expand Down Expand Up @@ -620,42 +612,4 @@ public Expression.SortField fromSortField(SortField s) {
public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) {
// Copies the proto option's name and its preference list (used as the option's values) into a
// newly built FunctionOption.
return FunctionOption.builder().name(o.getName()).addAllValues(o.getPreferenceList()).build();
}

/**
 * Tracks the parameter structs of the lambdas that enclose the expression currently being parsed.
 *
 * <p>Every lambda pushes its parameters before its body is parsed and pops them afterwards. A
 * lambda parameter reference selects one of the enclosing lambdas through its "stepsOut" value:
 *
 * <ul>
 *   <li>stepsOut=0 refers to the innermost (current) lambda
 *   <li>stepsOut=1 refers to the next enclosing lambda
 *   <li>stepsOut=N refers to N levels up
 * </ul>
 */
private static class LambdaParameterStack {
  /** Enclosing lambdas' parameters, outermost first and innermost last. */
  private final List<Type.Struct> stack = new ArrayList<>();

  /** Registers {@code parameters} as the new innermost lambda scope. */
  void push(Type.Struct parameters) {
    stack.add(parameters);
  }

  /**
   * Drops the innermost lambda scope.
   *
   * @throws IllegalArgumentException if no scope is currently registered
   */
  void pop() {
    int top = stack.size() - 1;
    if (top < 0) {
      throw new IllegalArgumentException("Lambda parameter stack is empty");
    }
    stack.remove(top);
  }

  /**
   * Looks up the parameters of the lambda {@code stepsOut} levels above the innermost one.
   *
   * @param stepsOut number of scopes to walk outward, 0 being the innermost lambda
   * @return the parameter struct registered for that scope
   * @throws IllegalArgumentException if {@code stepsOut} does not address a registered scope
   */
  Type.Struct get(int stepsOut) {
    int depth = stack.size();
    int index = depth - 1 - stepsOut;
    if (index >= 0 && index < depth) {
      return stack.get(index);
    }
    throw new IllegalArgumentException(
        String.format(
            "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)",
            stepsOut, depth));
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import io.substrait.expression.ExpressionVisitor;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.ImmutableExpression;
import io.substrait.util.EmptyVisitationContext;
import java.util.List;
import java.util.Optional;
Expand Down Expand Up @@ -448,7 +449,10 @@ public Optional<Expression> visit(Expression.Lambda lambda, EmptyVisitationConte
return Optional.empty();
}
return Optional.of(
Expression.Lambda.builder().from(lambda).body(newBody.orElse(lambda.body())).build());
ImmutableExpression.Lambda.builder()
.from(lambda)
.body(newBody.orElse(lambda.body()))
.build());
}

// utilities
Expand Down
107 changes: 107 additions & 0 deletions core/src/test/java/io/substrait/expression/LambdaBuilderTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package io.substrait.expression;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.List;
import org.junit.jupiter.api.Test;

/** Tests for {@link LambdaBuilder}. */
class LambdaBuilderTest {

  static final TypeCreator R = TypeCreator.REQUIRED;

  final LambdaBuilder lb = new LambdaBuilder();

  /** Builds the non-nullable parameter struct of a lambda taking one parameter of the given type. */
  private static Type.Struct singleParam(Type fieldType) {
    return Type.Struct.builder().nullable(false).addFields(fieldType).build();
  }

  /** {@code (x: i32)@p -> p[0]}: a reference to the lambda's own parameter has stepsOut=0. */
  @Test
  void simpleLambda() {
    FieldReference expectedBody =
        FieldReference.newLambdaParameterReference(0, singleParam(R.I32), 0);
    Expression.Lambda expected =
        ImmutableExpression.Lambda.builder()
            .parameters(singleParam(R.I32))
            .body(expectedBody)
            .build();

    Expression.Lambda actual = lb.lambda(List.of(R.I32), params -> params.ref(0));

    assertEquals(expected, actual);
  }

  /**
   * {@code (x: i32)@outer -> (y: i64)@inner -> outer[0]}: a reference to the outer parameter made
   * inside the inner lambda has stepsOut=1.
   */
  @Test
  void nestedLambda() {
    FieldReference outerParamSeenFromInner =
        FieldReference.newLambdaParameterReference(0, singleParam(R.I32), 1);
    Expression.Lambda expectedInner =
        ImmutableExpression.Lambda.builder()
            .parameters(singleParam(R.I64))
            .body(outerParamSeenFromInner)
            .build();
    Expression.Lambda expected =
        ImmutableExpression.Lambda.builder()
            .parameters(singleParam(R.I32))
            .body(expectedInner)
            .build();

    Expression.Lambda actual =
        lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> outer.ref(0)));

    assertEquals(expected, actual);
  }

  /**
   * The same scope handle yields different stepsOut values depending on where it is used:
   * {@code outer.ref(0)} has stepsOut=0 directly in the outer body and stepsOut=1 inside a nested
   * lambda.
   */
  @Test
  void scopeStepsOutChangesDynamically() {
    lb.lambda(
        List.of(R.I32),
        outer -> {
          FieldReference directRef = outer.ref(0);
          assertEquals(0, directRef.lambdaParameterReferenceStepsOut().orElse(-1));

          lb.lambda(
              List.of(R.I64),
              inner -> {
                FieldReference refFromInner = outer.ref(0);
                assertEquals(1, refFromInner.lambdaParameterReferenceStepsOut().orElse(-1));
                return inner.ref(0);
              });

          return directRef;
        });
  }

  /** {@code (x: i32)@p -> p[5]}: the lambda has a single parameter, so index 5 is rejected. */
  @Test
  void invalidFieldIndex_outOfBounds() {
    assertThrows(
        IndexOutOfBoundsException.class,
        () -> {
          lb.lambda(List.of(R.I32), params -> params.ref(5));
        });
  }

  /** {@code (x: i32)@p -> p[-1]}: a negative parameter index is rejected. */
  @Test
  void negativeFieldIndex() {
    assertThrows(
        Exception.class,
        () -> {
          lb.lambda(List.of(R.I32), params -> params.ref(-1));
        });
  }

  /**
   * {@code (x: i32)@outer -> (y: i64)@inner -> outer[5]}: the outer lambda has a single
   * parameter, so index 5 is rejected.
   */
  @Test
  void nestedOuterFieldIndexOutOfBounds() {
    assertThrows(
        IndexOutOfBoundsException.class,
        () -> {
          lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> outer.ref(5)));
        });
  }

  /**
   * {@code (x: i32)@outer -> (y: i64)@inner -> inner[3]}: the inner lambda has a single
   * parameter, so index 3 is rejected.
   */
  @Test
  void nestedInnerFieldIndexOutOfBounds() {
    assertThrows(
        IndexOutOfBoundsException.class,
        () -> {
          lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> inner.ref(3)));
        });
  }
}
Loading