Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -590,4 +590,17 @@ public O visit(Expression.ScalarSubquery expr, C context) throws E {
public O visit(Expression.InPredicate expr, C context) throws E {
return visitFallback(expr, context);
}

/**
* Visits a dynamic parameter expression.
*
* @param expr the dynamic parameter
* @param context the visitation context
* @return the visit result
* @throws E if visitation fails
*/
@Override
public O visit(Expression.DynamicParameter expr, C context) throws E {
return visitFallback(expr, context);
}
}
22 changes: 22 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,28 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
}
}

@Value.Immutable
abstract class DynamicParameter implements Expression {
public abstract Type type();

public abstract int parameterReference();

@Override
public Type getType() {
return type();
}

public static ImmutableExpression.DynamicParameter.Builder builder() {
return ImmutableExpression.DynamicParameter.builder();
}

@Override
public <R, C extends VisitationContext, E extends Throwable> R accept(
ExpressionVisitor<R, C, E> visitor, C context) throws E {
return visitor.visit(this, context);
}
}

enum PredicateOp {
PREDICATE_OP_UNSPECIFIED(
io.substrait.proto.Expression.Subquery.SetPredicate.PredicateOp.PREDICATE_OP_UNSPECIFIED),
Expand Down
10 changes: 10 additions & 0 deletions core/src/main/java/io/substrait/expression/ExpressionVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -460,4 +460,14 @@ public interface ExpressionVisitor<R, C extends VisitationContext, E extends Thr
* @throws E on visit failure
*/
R visit(Expression.InPredicate expr, C context) throws E;

/**
* Visit a dynamic parameter expression.
*
* @param expr the dynamic parameter expression
* @param context visitation context
* @return visit result
* @throws E on visit failure
*/
R visit(Expression.DynamicParameter expr, C context) throws E;
}
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,18 @@ public Expression visit(
.build();
}

@Override
public Expression visit(
io.substrait.expression.Expression.DynamicParameter expr, EmptyVisitationContext context)
throws RuntimeException {
return Expression.newBuilder()
.setDynamicParameter(
io.substrait.proto.DynamicParameter.newBuilder()
.setType(toProto(expr.type()))
.setParameterReference(expr.parameterReference()))
.build();
}

public static class BoundConverter
implements WindowBound.WindowBoundVisitor<Expression.WindowFunction.Bound, RuntimeException> {
private static final BoundConverter TO_BOUND_VISITOR = new BoundConverter();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,15 @@ public Type visit(Type.Struct type) throws RuntimeException {
}
}

case DYNAMIC_PARAMETER:
{
io.substrait.proto.DynamicParameter dp = expr.getDynamicParameter();
return Expression.DynamicParameter.builder()
.type(protoTypeConverter.from(dp.getType()))
.parameterReference(dp.getParameterReference())
.build();
}

// TODO enum.
case ENUM:
throw new UnsupportedOperationException("Unsupported type: " + expr.getRexTypeCase());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,12 @@ public Optional<Expression> visit(
.build());
}

@Override
public Optional<Expression> visit(
Expression.DynamicParameter expr, EmptyVisitationContext context) throws E {
return Optional.empty();
}

// utilities

protected Optional<List<Expression>> visitExprList(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package io.substrait.type.proto;

import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.TestBase;
import io.substrait.expression.Expression;
import io.substrait.type.TypeCreator;
import org.junit.jupiter.api.Test;

class DynamicParameterRoundtripTest extends TestBase {

@Test
void dynamicParameterI64() {
Expression.DynamicParameter dp =
Expression.DynamicParameter.builder()
.type(TypeCreator.REQUIRED.I64)
.parameterReference(0)
.build();

assertEquals(TypeCreator.REQUIRED.I64, dp.getType());
verifyRoundTrip(dp);
}

@Test
void dynamicParameterNullableString() {
Expression.DynamicParameter dp =
Expression.DynamicParameter.builder()
.type(TypeCreator.NULLABLE.STRING)
.parameterReference(1)
.build();

assertEquals(TypeCreator.NULLABLE.STRING, dp.getType());
verifyRoundTrip(dp);
}

@Test
void dynamicParameterFP64() {
Expression.DynamicParameter dp =
Expression.DynamicParameter.builder()
.type(TypeCreator.REQUIRED.FP64)
.parameterReference(2)
.build();

assertEquals(TypeCreator.REQUIRED.FP64, dp.getType());
verifyRoundTrip(dp);
}

@Test
void dynamicParameterI32Nullable() {
Expression.DynamicParameter dp =
Expression.DynamicParameter.builder()
.type(TypeCreator.NULLABLE.I32)
.parameterReference(42)
.build();

assertEquals(42, dp.parameterReference());
verifyRoundTrip(dp);
}

@Test
void dynamicParameterDate() {
Expression.DynamicParameter dp =
Expression.DynamicParameter.builder()
.type(TypeCreator.REQUIRED.DATE)
.parameterReference(3)
.build();

assertEquals(TypeCreator.REQUIRED.DATE, dp.getType());
verifyRoundTrip(dp);
}

@Test
void dynamicParameterBoolean() {
Expression.DynamicParameter dp =
Expression.DynamicParameter.builder()
.type(TypeCreator.REQUIRED.BOOLEAN)
.parameterReference(0)
.build();

assertEquals(TypeCreator.REQUIRED.BOOLEAN, dp.getType());
verifyRoundTrip(dp);
}

@Test
void dynamicParameterDecimal() {
Expression.DynamicParameter dp =
Expression.DynamicParameter.builder()
.type(TypeCreator.REQUIRED.decimal(10, 2))
.parameterReference(5)
.build();

assertEquals(TypeCreator.REQUIRED.decimal(10, 2), dp.getType());
verifyRoundTrip(dp);
}

@Test
void dynamicParameterTimestamp() {
Expression.DynamicParameter dp =
Expression.DynamicParameter.builder()
.type(TypeCreator.NULLABLE.TIMESTAMP)
.parameterReference(7)
.build();

assertEquals(TypeCreator.NULLABLE.TIMESTAMP, dp.getType());
verifyRoundTrip(dp);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -352,4 +352,10 @@ public String visit(EmptyMapLiteral expr, EmptyVisitationContext context)
throws RuntimeException {
return "<EmptyMapLiteral>";
}

@Override
public String visit(Expression.DynamicParameter expr, EmptyVisitationContext context)
throws RuntimeException {
return "<DynamicParameter " + expr.parameterReference() + " " + expr.type() + ">";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,12 @@ public RexNode visit(SetPredicate expr, Context context) throws RuntimeException
}
}

@Override
public RexNode visit(Expression.DynamicParameter expr, Context context) throws RuntimeException {
RelDataType calciteType = typeConverter.toCalcite(typeFactory, expr.type());
return rexBuilder.makeDynamicParam(calciteType, expr.parameterReference());
}

/**
* Helper method to create a Calcite ROW expression for encoding UDT struct literals.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,10 @@ public Expression visitCorrelVariable(RexCorrelVariable correlVariable) {

@Override
public Expression visitDynamicParam(RexDynamicParam dynamicParam) {
throw new UnsupportedOperationException("RexDynamicParam not supported");
return Expression.DynamicParameter.builder()
.type(typeConverter.toSubstrait(dynamicParam.getType()))
.parameterReference(dynamicParam.getIndex())
.build();
}

@Override
Expand Down
Loading
Loading