diff --git a/mlir/example/CMakeLists.txt b/mlir/example/CMakeLists.txt index b032946..da5afa2 100644 --- a/mlir/example/CMakeLists.txt +++ b/mlir/example/CMakeLists.txt @@ -42,7 +42,6 @@ add_subdirectory(Ch4) add_subdirectory(Ch5) add_subdirectory(Ch6) add_subdirectory(Ch7) -add_subdirectory(Ch8) add_subdirectory(transform_Ch2) add_subdirectory(transform_Ch3) add_subdirectory(transform_Ch4) diff --git a/mlir/example/Ch1/include/toy/Lexer.h b/mlir/example/Ch1/include/toy/Lexer.h index ecbb3b4..d420a7e 100644 --- a/mlir/example/Ch1/include/toy/Lexer.h +++ b/mlir/example/Ch1/include/toy/Lexer.h @@ -15,6 +15,7 @@ #include "llvm/ADT/StringRef.h" +#include #include #include diff --git a/mlir/example/Ch1/parser/AST.cpp b/mlir/example/Ch1/parser/AST.cpp index 2546f2a..8416424 100644 --- a/mlir/example/Ch1/parser/AST.cpp +++ b/mlir/example/Ch1/parser/AST.cpp @@ -120,7 +120,7 @@ void ASTDumper::dump(NumberExprAST *num) { /// [ [ 1, 2 ], [ 3, 4 ] ] /// We print out such array with the dimensions spelled out at every level: /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] -void printLitHelper(ExprAST *litOrNum) { +static void printLitHelper(ExprAST *litOrNum) { // Inside a literal expression we can have either a number or another literal if (auto *num = llvm::dyn_cast(litOrNum)) { llvm::errs() << num->getValue(); diff --git a/mlir/example/Ch1/toyc.cpp b/mlir/example/Ch1/toyc.cpp index fb7b484..b9f3a2d 100644 --- a/mlir/example/Ch1/toyc.cpp +++ b/mlir/example/Ch1/toyc.cpp @@ -39,7 +39,8 @@ static cl::opt cl::values(clEnumValN(DumpAST, "ast", "output the AST dump"))); /// Returns a Toy AST resulting from parsing the file or a nullptr on error. -std::unique_ptr parseInputFile(llvm::StringRef filename) { +static std::unique_ptr +parseInputFile(llvm::StringRef filename) { llvm::ErrorOr> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); if (std::error_code ec = fileOrErr.getError()) { diff --git a/mlir/example/Ch2/include/toy/Lexer.h b/mlir/example/Ch2/include/toy/Lexer.h index 3c59cd9..22822cc 100644 --- a/mlir/example/Ch2/include/toy/Lexer.h +++ b/mlir/example/Ch2/include/toy/Lexer.h @@ -15,6 +15,7 @@ #include "llvm/ADT/StringRef.h" +#include #include #include diff --git a/mlir/example/Ch2/include/toy/Ops.td b/mlir/example/Ch2/include/toy/Ops.td index 1a1b136..91bf83a 100644 --- a/mlir/example/Ch2/include/toy/Ops.td +++ b/mlir/example/Ch2/include/toy/Ops.td @@ -70,7 +70,7 @@ def ConstantOp : Toy_Op<"constant", [Pure]> { // Add custom build methods for the constant operation. These method populates // the `state` that MLIR uses to create operations, i.e. these are used when - // using `builder.create(...)`. + // using `ConstantOp::create(builder, ...)`. let builders = [ // Build a constant with a given constant tensor value. OpBuilder<(ins "DenseElementsAttr":$value), [{ @@ -297,7 +297,7 @@ def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">, // Allow building a ReturnOp with no return operand. let builders = [ - OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]> + OpBuilder<(ins), [{ build($_builder, $_state, {}); }]> ]; // Provide extra utility definitions on the c++ operation class definition. diff --git a/mlir/example/Ch2/mlir/MLIRGen.cpp b/mlir/example/Ch2/mlir/MLIRGen.cpp index bf4c099..a9592bc 100644 --- a/mlir/example/Ch2/mlir/MLIRGen.cpp +++ b/mlir/example/Ch2/mlir/MLIRGen.cpp @@ -120,9 +120,9 @@ class MLIRGenImpl { // Arguments type are uniformly unranked tensors. llvm::SmallVector argTypes(proto.getArgs().size(), getType(VarType{})); - auto funcType = builder.getFunctionType(argTypes, std::nullopt); - return builder.create(location, proto.getName(), - funcType); + auto funcType = builder.getFunctionType(argTypes, {}); + return mlir::toy::FuncOp::create(builder, location, proto.getName(), + funcType); } /// Emit a new function and add it to the MLIR module. @@ -166,7 +166,7 @@ class MLIRGenImpl { if (!entryBlock.empty()) returnOp = dyn_cast(entryBlock.back()); if (!returnOp) { - builder.create(loc(funcAST.getProto()->loc())); + ReturnOp::create(builder, loc(funcAST.getProto()->loc())); } else if (returnOp.hasOperand()) { // Otherwise, if this return operation has an operand then add a result to // the function. @@ -202,9 +202,9 @@ class MLIRGenImpl { // support '+' and '*'. switch (binop.getOp()) { case '+': - return builder.create(location, lhs, rhs); + return AddOp::create(builder, location, lhs, rhs); case '*': - return builder.create(location, lhs, rhs); + return MulOp::create(builder, location, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; @@ -235,8 +235,8 @@ class MLIRGenImpl { } // Otherwise, this return operation has zero operands. - builder.create(location, - expr ? ArrayRef(expr) : ArrayRef()); + ReturnOp::create(builder, location, + expr ? ArrayRef(expr) : ArrayRef()); return mlir::success(); } @@ -264,8 +264,7 @@ class MLIRGenImpl { // The attribute is a vector with a floating point value per element // (number) in the array, see `collectData()` below for more details. std::vector data; - data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, - std::multiplies())); + data.reserve(llvm::product_of(lit.getDims())); collectData(lit, data); // The type of this attribute is tensor of 64-bit floating-point with the @@ -280,7 +279,7 @@ class MLIRGenImpl { // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` // method. - return builder.create(loc(lit.loc()), type, dataAttribute); + return ConstantOp::create(builder, loc(lit.loc()), type, dataAttribute); } /// Recursive helper function to accumulate the data that compose an array @@ -325,13 +324,13 @@ class MLIRGenImpl { "does not accept multiple arguments"); return nullptr; } - return builder.create(location, operands[0]); + return TransposeOp::create(builder, location, operands[0]); } // Otherwise this is a call to a user-defined function. Calls to // user-defined functions are mapped to a custom call that takes the callee // name as an attribute. - return builder.create(location, callee, operands); + return GenericCallOp::create(builder, location, callee, operands); } /// Emit a print expression. It emits specific operations for two builtins: @@ -341,13 +340,13 @@ class MLIRGenImpl { if (!arg) return mlir::failure(); - builder.create(loc(call.loc()), arg); + PrintOp::create(builder, loc(call.loc()), arg); return mlir::success(); } /// Emit a constant for a single number (FIXME: semantic? broadcast?) mlir::Value mlirGen(NumberExprAST &num) { - return builder.create(loc(num.loc()), num.getValue()); + return ConstantOp::create(builder, loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. @@ -391,8 +390,8 @@ class MLIRGenImpl { // with specific shape, we emit a "reshape" operation. It will get // optimized out later as needed. if (!vardecl.getType().shape.empty()) { - value = builder.create(loc(vardecl.loc()), - getType(vardecl.getType()), value); + value = ReshapeOp::create(builder, loc(vardecl.loc()), + getType(vardecl.getType()), value); } // Register the value in the symbol table. diff --git a/mlir/example/Ch2/parser/AST.cpp b/mlir/example/Ch2/parser/AST.cpp index 2546f2a..8416424 100644 --- a/mlir/example/Ch2/parser/AST.cpp +++ b/mlir/example/Ch2/parser/AST.cpp @@ -120,7 +120,7 @@ void ASTDumper::dump(NumberExprAST *num) { /// [ [ 1, 2 ], [ 3, 4 ] ] /// We print out such array with the dimensions spelled out at every level: /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] -void printLitHelper(ExprAST *litOrNum) { +static void printLitHelper(ExprAST *litOrNum) { // Inside a literal expression we can have either a number or another literal if (auto *num = llvm::dyn_cast(litOrNum)) { llvm::errs() << num->getValue(); diff --git a/mlir/example/Ch2/toyc.cpp b/mlir/example/Ch2/toyc.cpp index e33b49b..a60738d 100644 --- a/mlir/example/Ch2/toyc.cpp +++ b/mlir/example/Ch2/toyc.cpp @@ -58,7 +58,8 @@ static cl::opt emitAction( cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump"))); /// Returns a Toy AST resulting from parsing the file or a nullptr on error. -std::unique_ptr parseInputFile(llvm::StringRef filename) { +static std::unique_ptr +parseInputFile(llvm::StringRef filename) { llvm::ErrorOr> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); if (std::error_code ec = fileOrErr.getError()) { @@ -71,7 +72,7 @@ std::unique_ptr parseInputFile(llvm::StringRef filename) { return parser.parseModule(); } -int dumpMLIR() { +static int dumpMLIR() { mlir::MLIRContext context; // Load our Dialect in this MLIR Context. context.getOrLoadDialect(); @@ -112,7 +113,7 @@ int dumpMLIR() { return 0; } -int dumpAST() { +static int dumpAST() { if (inputType == InputType::MLIR) { llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; return 5; diff --git a/mlir/example/Ch3/include/toy/Lexer.h b/mlir/example/Ch3/include/toy/Lexer.h index 3c59cd9..22822cc 100644 --- a/mlir/example/Ch3/include/toy/Lexer.h +++ b/mlir/example/Ch3/include/toy/Lexer.h @@ -15,6 +15,7 @@ #include "llvm/ADT/StringRef.h" +#include #include #include diff --git a/mlir/example/Ch3/include/toy/Ops.td b/mlir/example/Ch3/include/toy/Ops.td index 021802b..027b076 100644 --- a/mlir/example/Ch3/include/toy/Ops.td +++ b/mlir/example/Ch3/include/toy/Ops.td @@ -69,7 +69,7 @@ def ConstantOp : Toy_Op<"constant", [Pure]> { // Add custom build methods for the constant operation. These method populates // the `state` that MLIR uses to create operations, i.e. these are used when - // using `builder.create(...)`. + // using `ConstantOp::create(builder, ...)`. let builders = [ // Build a constant with a given constant tensor value. OpBuilder<(ins "DenseElementsAttr":$value), [{ @@ -298,7 +298,7 @@ def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">, // Allow building a ReturnOp with no return operand. let builders = [ - OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]> + OpBuilder<(ins), [{ build($_builder, $_state, {}); }]> ]; // Provide extra utility definitions on the c++ operation class definition. diff --git a/mlir/example/Ch3/mlir/MLIRGen.cpp b/mlir/example/Ch3/mlir/MLIRGen.cpp index bf4c099..8c21951 100644 --- a/mlir/example/Ch3/mlir/MLIRGen.cpp +++ b/mlir/example/Ch3/mlir/MLIRGen.cpp @@ -120,9 +120,9 @@ class MLIRGenImpl { // Arguments type are uniformly unranked tensors. llvm::SmallVector argTypes(proto.getArgs().size(), getType(VarType{})); - auto funcType = builder.getFunctionType(argTypes, std::nullopt); - return builder.create(location, proto.getName(), - funcType); + auto funcType = builder.getFunctionType(argTypes, /*results=*/{}); + return mlir::toy::FuncOp::create(builder, location, proto.getName(), + funcType); } /// Emit a new function and add it to the MLIR module. @@ -166,7 +166,7 @@ class MLIRGenImpl { if (!entryBlock.empty()) returnOp = dyn_cast(entryBlock.back()); if (!returnOp) { - builder.create(loc(funcAST.getProto()->loc())); + ReturnOp::create(builder, loc(funcAST.getProto()->loc())); } else if (returnOp.hasOperand()) { // Otherwise, if this return operation has an operand then add a result to // the function. @@ -202,9 +202,9 @@ class MLIRGenImpl { // support '+' and '*'. switch (binop.getOp()) { case '+': - return builder.create(location, lhs, rhs); + return AddOp::create(builder, location, lhs, rhs); case '*': - return builder.create(location, lhs, rhs); + return MulOp::create(builder, location, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; @@ -235,8 +235,8 @@ class MLIRGenImpl { } // Otherwise, this return operation has zero operands. - builder.create(location, - expr ? ArrayRef(expr) : ArrayRef()); + ReturnOp::create(builder, location, + expr ? ArrayRef(expr) : ArrayRef()); return mlir::success(); } @@ -264,8 +264,7 @@ class MLIRGenImpl { // The attribute is a vector with a floating point value per element // (number) in the array, see `collectData()` below for more details. std::vector data; - data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, - std::multiplies())); + data.reserve(llvm::product_of(lit.getDims())); collectData(lit, data); // The type of this attribute is tensor of 64-bit floating-point with the @@ -280,7 +279,7 @@ class MLIRGenImpl { // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` // method. - return builder.create(loc(lit.loc()), type, dataAttribute); + return ConstantOp::create(builder, loc(lit.loc()), type, dataAttribute); } /// Recursive helper function to accumulate the data that compose an array @@ -325,13 +324,13 @@ class MLIRGenImpl { "does not accept multiple arguments"); return nullptr; } - return builder.create(location, operands[0]); + return TransposeOp::create(builder, location, operands[0]); } // Otherwise this is a call to a user-defined function. Calls to // user-defined functions are mapped to a custom call that takes the callee // name as an attribute. - return builder.create(location, callee, operands); + return GenericCallOp::create(builder, location, callee, operands); } /// Emit a print expression. It emits specific operations for two builtins: @@ -341,13 +340,13 @@ class MLIRGenImpl { if (!arg) return mlir::failure(); - builder.create(loc(call.loc()), arg); + PrintOp::create(builder, loc(call.loc()), arg); return mlir::success(); } /// Emit a constant for a single number (FIXME: semantic? broadcast?) mlir::Value mlirGen(NumberExprAST &num) { - return builder.create(loc(num.loc()), num.getValue()); + return ConstantOp::create(builder, loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. @@ -391,8 +390,8 @@ class MLIRGenImpl { // with specific shape, we emit a "reshape" operation. It will get // optimized out later as needed. if (!vardecl.getType().shape.empty()) { - value = builder.create(loc(vardecl.loc()), - getType(vardecl.getType()), value); + value = ReshapeOp::create(builder, loc(vardecl.loc()), + getType(vardecl.getType()), value); } // Register the value in the symbol table. diff --git a/mlir/example/Ch3/parser/AST.cpp b/mlir/example/Ch3/parser/AST.cpp index 2546f2a..8416424 100644 --- a/mlir/example/Ch3/parser/AST.cpp +++ b/mlir/example/Ch3/parser/AST.cpp @@ -120,7 +120,7 @@ void ASTDumper::dump(NumberExprAST *num) { /// [ [ 1, 2 ], [ 3, 4 ] ] /// We print out such array with the dimensions spelled out at every level: /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] -void printLitHelper(ExprAST *litOrNum) { +static void printLitHelper(ExprAST *litOrNum) { // Inside a literal expression we can have either a number or another literal if (auto *num = llvm::dyn_cast(litOrNum)) { llvm::errs() << num->getValue(); diff --git a/mlir/example/Ch3/toyc.cpp b/mlir/example/Ch3/toyc.cpp index f8aa846..3094935 100644 --- a/mlir/example/Ch3/toyc.cpp +++ b/mlir/example/Ch3/toyc.cpp @@ -64,7 +64,8 @@ static cl::opt emitAction( static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); /// Returns a Toy AST resulting from parsing the file or a nullptr on error. -std::unique_ptr parseInputFile(llvm::StringRef filename) { +static std::unique_ptr +parseInputFile(llvm::StringRef filename) { llvm::ErrorOr> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); if (std::error_code ec = fileOrErr.getError()) { @@ -77,8 +78,8 @@ std::unique_ptr parseInputFile(llvm::StringRef filename) { return parser.parseModule(); } -int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, - mlir::OwningOpRef &module) { +static int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, + mlir::OwningOpRef &module) { // Handle '.toy' input to the compiler. if (inputType != InputType::MLIR && !llvm::StringRef(inputFilename).ends_with(".mlir")) { @@ -107,7 +108,7 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, return 0; } -int dumpMLIR() { +static int dumpMLIR() { mlir::MLIRContext context; // Load our Dialect in this MLIR Context. context.getOrLoadDialect(); @@ -134,7 +135,7 @@ int dumpMLIR() { return 0; } -int dumpAST() { +static int dumpAST() { if (inputType == InputType::MLIR) { llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; return 5; diff --git a/mlir/example/Ch4/include/toy/Lexer.h b/mlir/example/Ch4/include/toy/Lexer.h index 3c59cd9..22822cc 100644 --- a/mlir/example/Ch4/include/toy/Lexer.h +++ b/mlir/example/Ch4/include/toy/Lexer.h @@ -15,6 +15,7 @@ #include "llvm/ADT/StringRef.h" +#include #include #include diff --git a/mlir/example/Ch4/include/toy/Ops.td b/mlir/example/Ch4/include/toy/Ops.td index 075fd1a..6c6b739 100644 --- a/mlir/example/Ch4/include/toy/Ops.td +++ b/mlir/example/Ch4/include/toy/Ops.td @@ -72,7 +72,7 @@ def ConstantOp : Toy_Op<"constant", [Pure]> { // Add custom build methods for the constant operation. These method populates // the `state` that MLIR uses to create operations, i.e. these are used when - // using `builder.create(...)`. + // using `ConstantOp::create(builder, ...)`. let builders = [ // Build a constant with a given constant tensor value. OpBuilder<(ins "DenseElementsAttr":$value), [{ @@ -215,7 +215,12 @@ def GenericCallOp : Toy_Op<"generic_call", // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$inputs, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); // The generic call operation returns a single value of TensorType. let results = (outs F64Tensor); @@ -330,7 +335,7 @@ def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">, // Allow building a ReturnOp with no return operand. let builders = [ - OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]> + OpBuilder<(ins), [{ build($_builder, $_state, {}); }]> ]; // Provide extra utility definitions on the c++ operation class definition. diff --git a/mlir/example/Ch4/mlir/Dialect.cpp b/mlir/example/Ch4/mlir/Dialect.cpp index 6c6cdd9..1e5e672 100644 --- a/mlir/example/Ch4/mlir/Dialect.cpp +++ b/mlir/example/Ch4/mlir/Dialect.cpp @@ -91,7 +91,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { Operation *materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const final { - return builder.create(conversionLoc, resultType, input); + return CastOp::create(builder, conversionLoc, resultType, input); } }; @@ -206,7 +206,8 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) { llvm::LogicalResult ConstantOp::verify() { // If the return type of the constant is not an unranked tensor, the shape // must match the shape of the attribute holding the data. - auto resultType = llvm::dyn_cast(getResult().getType()); + auto resultType = + llvm::dyn_cast(getResult().getType()); if (!resultType) return success(); @@ -333,7 +334,7 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() { /// Set the callee for the generic call operation, this is required by the call /// interface. void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { - (*this)->setAttr("callee", callee.get()); + (*this)->setAttr("callee", cast(callee)); } /// Get the argument operands to the called function, this is required by the @@ -395,7 +396,8 @@ llvm::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. - if (inputType == resultType || llvm::isa(inputType) || + if (inputType == resultType || + llvm::isa(inputType) || llvm::isa(resultType)) return mlir::success(); diff --git a/mlir/example/Ch4/mlir/MLIRGen.cpp b/mlir/example/Ch4/mlir/MLIRGen.cpp index b56e2f7..6b7ab40 100644 --- a/mlir/example/Ch4/mlir/MLIRGen.cpp +++ b/mlir/example/Ch4/mlir/MLIRGen.cpp @@ -120,9 +120,9 @@ class MLIRGenImpl { // Arguments type are uniformly unranked tensors. llvm::SmallVector argTypes(proto.getArgs().size(), getType(VarType{})); - auto funcType = builder.getFunctionType(argTypes, std::nullopt); - return builder.create(location, proto.getName(), - funcType); + auto funcType = builder.getFunctionType(argTypes, /*results=*/{}); + return mlir::toy::FuncOp::create(builder, location, proto.getName(), + funcType); } /// Emit a new function and add it to the MLIR module. @@ -166,7 +166,7 @@ class MLIRGenImpl { if (!entryBlock.empty()) returnOp = dyn_cast(entryBlock.back()); if (!returnOp) { - builder.create(loc(funcAST.getProto()->loc())); + ReturnOp::create(builder, loc(funcAST.getProto()->loc())); } else if (returnOp.hasOperand()) { // Otherwise, if this return operation has an operand then add a result to // the function. @@ -206,9 +206,9 @@ class MLIRGenImpl { // support '+' and '*'. switch (binop.getOp()) { case '+': - return builder.create(location, lhs, rhs); + return AddOp::create(builder, location, lhs, rhs); case '*': - return builder.create(location, lhs, rhs); + return MulOp::create(builder, location, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; @@ -239,8 +239,8 @@ class MLIRGenImpl { } // Otherwise, this return operation has zero operands. - builder.create(location, - expr ? ArrayRef(expr) : ArrayRef()); + ReturnOp::create(builder, location, + expr ? ArrayRef(expr) : ArrayRef()); return mlir::success(); } @@ -268,8 +268,7 @@ class MLIRGenImpl { // The attribute is a vector with a floating point value per element // (number) in the array, see `collectData()` below for more details. std::vector data; - data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, - std::multiplies())); + data.reserve(llvm::product_of(lit.getDims())); collectData(lit, data); // The type of this attribute is tensor of 64-bit floating-point with the @@ -284,7 +283,7 @@ class MLIRGenImpl { // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` // method. - return builder.create(loc(lit.loc()), type, dataAttribute); + return ConstantOp::create(builder, loc(lit.loc()), type, dataAttribute); } /// Recursive helper function to accumulate the data that compose an array @@ -329,13 +328,13 @@ class MLIRGenImpl { "does not accept multiple arguments"); return nullptr; } - return builder.create(location, operands[0]); + return TransposeOp::create(builder, location, operands[0]); } // Otherwise this is a call to a user-defined function. Calls to // user-defined functions are mapped to a custom call that takes the callee // name as an attribute. - return builder.create(location, callee, operands); + return GenericCallOp::create(builder, location, callee, operands); } /// Emit a print expression. It emits specific operations for two builtins: @@ -345,13 +344,13 @@ class MLIRGenImpl { if (!arg) return mlir::failure(); - builder.create(loc(call.loc()), arg); + PrintOp::create(builder, loc(call.loc()), arg); return mlir::success(); } /// Emit a constant for a single number (FIXME: semantic? broadcast?) mlir::Value mlirGen(NumberExprAST &num) { - return builder.create(loc(num.loc()), num.getValue()); + return ConstantOp::create(builder, loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. @@ -395,8 +394,8 @@ class MLIRGenImpl { // with specific shape, we emit a "reshape" operation. It will get // optimized out later as needed. if (!vardecl.getType().shape.empty()) { - value = builder.create(loc(vardecl.loc()), - getType(vardecl.getType()), value); + value = ReshapeOp::create(builder, loc(vardecl.loc()), + getType(vardecl.getType()), value); } // Register the value in the symbol table. diff --git a/mlir/example/Ch4/mlir/ShapeInferencePass.cpp b/mlir/example/Ch4/mlir/ShapeInferencePass.cpp index a9e995e..a552e1f 100644 --- a/mlir/example/Ch4/mlir/ShapeInferencePass.cpp +++ b/mlir/example/Ch4/mlir/ShapeInferencePass.cpp @@ -23,7 +23,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include @@ -55,6 +55,7 @@ namespace { struct ShapeInferencePass : public mlir::PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass) + StringRef getArgument() const override { return "toy-shape-inference"; } void runOnOperation() override { auto f = getOperation(); @@ -80,7 +81,7 @@ struct ShapeInferencePass opWorklist.erase(op); // Ask the operation to infer its output shapes. - LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + LDBG() << "Inferring shape for: " << *op; if (auto shapeOp = dyn_cast(op)) { shapeOp.inferShapes(); } else { diff --git a/mlir/example/Ch4/parser/AST.cpp b/mlir/example/Ch4/parser/AST.cpp index 2546f2a..8416424 100644 --- a/mlir/example/Ch4/parser/AST.cpp +++ b/mlir/example/Ch4/parser/AST.cpp @@ -120,7 +120,7 @@ void ASTDumper::dump(NumberExprAST *num) { /// [ [ 1, 2 ], [ 3, 4 ] ] /// We print out such array with the dimensions spelled out at every level: /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] -void printLitHelper(ExprAST *litOrNum) { +static void printLitHelper(ExprAST *litOrNum) { // Inside a literal expression we can have either a number or another literal if (auto *num = llvm::dyn_cast(litOrNum)) { llvm::errs() << num->getValue(); diff --git a/mlir/example/Ch4/toyc.cpp b/mlir/example/Ch4/toyc.cpp index ae02bc4..36816f0 100644 --- a/mlir/example/Ch4/toyc.cpp +++ b/mlir/example/Ch4/toyc.cpp @@ -65,7 +65,8 @@ static cl::opt emitAction( static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); /// Returns a Toy AST resulting from parsing the file or a nullptr on error. -std::unique_ptr parseInputFile(llvm::StringRef filename) { +static std::unique_ptr +parseInputFile(llvm::StringRef filename) { llvm::ErrorOr> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); if (std::error_code ec = fileOrErr.getError()) { @@ -78,8 +79,8 @@ std::unique_ptr parseInputFile(llvm::StringRef filename) { return parser.parseModule(); } -int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, - mlir::OwningOpRef &module) { +static int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, + mlir::OwningOpRef &module) { // Handle '.toy' input to the compiler. if (inputType != InputType::MLIR && !llvm::StringRef(inputFilename).ends_with(".mlir")) { @@ -108,7 +109,7 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, return 0; } -int dumpMLIR() { +static int dumpMLIR() { mlir::MLIRContext context; // Load our Dialect in this MLIR Context. context.getOrLoadDialect(); @@ -143,7 +144,7 @@ int dumpMLIR() { return 0; } -int dumpAST() { +static int dumpAST() { if (inputType == InputType::MLIR) { llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; return 5; diff --git a/mlir/example/Ch5/include/toy/Lexer.h b/mlir/example/Ch5/include/toy/Lexer.h index 3c59cd9..22822cc 100644 --- a/mlir/example/Ch5/include/toy/Lexer.h +++ b/mlir/example/Ch5/include/toy/Lexer.h @@ -15,6 +15,7 @@ #include "llvm/ADT/StringRef.h" +#include #include #include diff --git a/mlir/example/Ch5/include/toy/Ops.td b/mlir/example/Ch5/include/toy/Ops.td index ec6762f..6a136ec 100644 --- a/mlir/example/Ch5/include/toy/Ops.td +++ b/mlir/example/Ch5/include/toy/Ops.td @@ -72,7 +72,7 @@ def ConstantOp : Toy_Op<"constant", [Pure]> { // Add custom build methods for the constant operation. These method populates // the `state` that MLIR uses to create operations, i.e. these are used when - // using `builder.create(...)`. + // using `ConstantOp::create(builder, ...)`. let builders = [ // Build a constant with a given constant tensor value. OpBuilder<(ins "DenseElementsAttr":$value), [{ @@ -214,7 +214,12 @@ def GenericCallOp : Toy_Op<"generic_call", // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$inputs, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); // The generic call operation returns a single value of TensorType. let results = (outs F64Tensor); @@ -330,7 +335,7 @@ def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">, // Allow building a ReturnOp with no return operand. let builders = [ - OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]> + OpBuilder<(ins), [{ build($_builder, $_state, {}); }]> ]; // Provide extra utility definitions on the c++ operation class definition. diff --git a/mlir/example/Ch5/mlir/Dialect.cpp b/mlir/example/Ch5/mlir/Dialect.cpp index 72072f9..69fb69f 100644 --- a/mlir/example/Ch5/mlir/Dialect.cpp +++ b/mlir/example/Ch5/mlir/Dialect.cpp @@ -91,7 +91,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { Operation *materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const final { - return builder.create(conversionLoc, resultType, input); + return CastOp::create(builder, conversionLoc, resultType, input); } }; @@ -206,7 +206,8 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) { llvm::LogicalResult ConstantOp::verify() { // If the return type of the constant is not an unranked tensor, the shape // must match the shape of the attribute holding the data. - auto resultType = llvm::dyn_cast(getResult().getType()); + auto resultType = + llvm::dyn_cast(getResult().getType()); if (!resultType) return success(); @@ -333,7 +334,7 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() { /// Set the callee for the generic call operation, this is required by the call /// interface. void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { - (*this)->setAttr("callee", callee.get()); + (*this)->setAttr("callee", cast(callee)); } /// Get the argument operands to the called function, this is required by the @@ -395,7 +396,8 @@ llvm::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. - if (inputType == resultType || llvm::isa(inputType) || + if (inputType == resultType || + llvm::isa(inputType) || llvm::isa(resultType)) return mlir::success(); diff --git a/mlir/example/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/example/Ch5/mlir/LowerToAffineLoops.cpp index 7413214..2969d3a 100644 --- a/mlir/example/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/example/Ch5/mlir/LowerToAffineLoops.cpp @@ -44,7 +44,7 @@ using namespace mlir; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns +// ToyToAffine Conversion Patterns //===----------------------------------------------------------------------===// /// Convert the given RankedTensorType into the corresponding MemRefType. @@ -55,7 +55,7 @@ static MemRefType convertTensorToMemRef(RankedTensorType type) { /// Insert an allocation and deallocation for the given MemRefType. static Value insertAllocAndDealloc(MemRefType type, Location loc, PatternRewriter &rewriter) { - auto alloc = rewriter.create(loc, type); + auto alloc = memref::AllocOp::create(rewriter, loc, type); // Make sure to allocate at the beginning of the block. auto *parentBlock = alloc->getBlock(); @@ -63,21 +63,19 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc, // Make sure to deallocate this alloc at the end of the block. This is fine // as toy functions have no control flow. - auto dealloc = rewriter.create(loc, alloc); + auto dealloc = memref::DeallocOp::create(rewriter, loc, alloc); dealloc->moveBefore(&parentBlock->back()); return alloc; } /// This defines the function type used to process an iteration of a lowered -/// loop. It takes as input an OpBuilder, an range of memRefOperands -/// corresponding to the operands of the input operation, and the range of loop -/// induction variables for the iteration. It returns a value to store at the -/// current index of the iteration. -using LoopIterationFn = function_ref; - -static void lowerOpToLoops(Operation *op, ValueRange operands, - PatternRewriter &rewriter, +/// loop. It takes as input an OpBuilder and the range of loop induction +/// variables for the iteration. It returns a value to store at the current +/// index of the iteration. +using LoopIterationFn = + function_ref; + +static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter, LoopIterationFn processIteration) { auto tensorType = llvm::cast((*op->result_type_begin())); auto loc = op->getLoc(); @@ -95,12 +93,12 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, affine::buildAffineLoopNest( rewriter, loc, lowerBounds, tensorType.getShape(), steps, [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { - // Call the processing function with the rewriter, the memref operands, - // and the loop induction variables. This function will return the value - // to store at the current index. - Value valueToStore = processIteration(nestedBuilder, operands, ivs); - nestedBuilder.create(loc, valueToStore, alloc, - ivs); + // Call the processing function with the rewriter and the loop + // induction variables. This function will return the value to store at + // the current index. + Value valueToStore = processIteration(nestedBuilder, ivs); + affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc, + ivs); }); // Replace this operation with the generated alloc. @@ -109,38 +107,30 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, namespace { //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Binary operations +// ToyToAffine Conversion Patterns: Binary operations //===----------------------------------------------------------------------===// template -struct BinaryOpLowering : public ConversionPattern { - BinaryOpLowering(MLIRContext *ctx) - : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} +struct BinaryOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OpConversionPattern::OpAdaptor; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // BinaryOp. This allows for using the nice named accessors - // that are generated by the ODS. - typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); - - // Generate loads for the element of 'lhs' and 'rhs' at the - // inner loop. - auto loadedLhs = builder.create( - loc, binaryAdaptor.getLhs(), loopIvs); - auto loadedRhs = builder.create( - loc, binaryAdaptor.getRhs(), loopIvs); - - // Create the binary operation performed on the loaded - // values. - return builder.create(loc, loadedLhs, - loadedRhs); - }); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { + // Generate loads for the element of 'lhs' and 'rhs' at the + // inner loop. + auto loadedLhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs); + auto loadedRhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs); + + // Create the binary operation performed on the loaded + // values. + return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs); + }); return success(); } }; @@ -148,14 +138,15 @@ using AddOpLowering = BinaryOpLowering; using MulOpLowering = BinaryOpLowering; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Constant operations +// ToyToAffine Conversion Patterns: Constant operations //===----------------------------------------------------------------------===// -struct ConstantOpLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ConstantOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(toy::ConstantOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { DenseElementsAttr constantValue = op.getValue(); Location loc = op.getLoc(); @@ -174,11 +165,11 @@ struct ConstantOpLowering : public OpRewritePattern { if (!valueShape.empty()) { for (auto i : llvm::seq(0, *llvm::max_element(valueShape))) constantIndices.push_back( - rewriter.create(loc, i)); + arith::ConstantIndexOp::create(rewriter, loc, i)); } else { // This is the case of a tensor of rank 0. constantIndices.push_back( - rewriter.create(loc, 0)); + arith::ConstantIndexOp::create(rewriter, loc, 0)); } // The constant operation represents a multi-dimensional constant, so we @@ -191,9 +182,9 @@ struct ConstantOpLowering : public OpRewritePattern { // The last dimension is the base case of the recursion, at this point // we store the element at the given index. if (dimension == valueShape.size()) { - rewriter.create( - loc, rewriter.create(loc, *valueIt++), alloc, - llvm::ArrayRef(indices)); + affine::AffineStoreOp::create( + rewriter, loc, arith::ConstantOp::create(rewriter, loc, *valueIt++), + alloc, llvm::ArrayRef(indices)); return; } @@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Func operations +// ToyToAffine Conversion Patterns: Func operations //===----------------------------------------------------------------------===// struct FuncOpLowering : public OpConversionPattern { @@ -238,8 +229,8 @@ struct FuncOpLowering : public OpConversionPattern { } // Create a new non-toy function, with the same region. - auto func = rewriter.create(op.getLoc(), op.getName(), - op.getFunctionType()); + auto func = mlir::func::FuncOp::create(rewriter, op.getLoc(), op.getName(), + op.getFunctionType()); rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); rewriter.eraseOp(op); return success(); @@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Print operations +// ToyToAffine Conversion Patterns: Print operations //===----------------------------------------------------------------------===// struct PrintOpLowering : public OpConversionPattern { @@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Return operations +// ToyToAffine Conversion Patterns: Return operations //===----------------------------------------------------------------------===// -struct ReturnOpLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ReturnOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(toy::ReturnOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { // During this lowering, we expect that all function calls have been // inlined. if (op.hasOperand()) @@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Transpose operations +// ToyToAffine Conversion Patterns: Transpose operations //===----------------------------------------------------------------------===// -struct TransposeOpLowering : public ConversionPattern { - TransposeOpLowering(MLIRContext *ctx) - : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} +struct TransposeOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // TransposeOp. This allows for using the nice named - // accessors that are generated by the ODS. - toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); - Value input = transposeAdaptor.getInput(); - - // Transpose the elements by generating a load from the - // reverse indices. - SmallVector reverseIvs(llvm::reverse(loopIvs)); - return builder.create(loc, input, - reverseIvs); - }); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { + Value input = adaptor.getInput(); + + // Transpose the elements by generating a load from the + // reverse indices. + SmallVector reverseIvs(llvm::reverse(loopIvs)); + return affine::AffineLoadOp::create(builder, loc, input, reverseIvs); + }); return success(); } }; @@ -328,6 +312,7 @@ namespace { struct ToyToAffineLoweringPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ToyToAffineLoweringPass) + StringRef getArgument() const override { return "toy-to-affine"; } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert argTypes(proto.getArgs().size(), getType(VarType{})); - auto funcType = builder.getFunctionType(argTypes, std::nullopt); - return builder.create(location, proto.getName(), - funcType); + auto funcType = builder.getFunctionType(argTypes, /*results=*/{}); + return mlir::toy::FuncOp::create(builder, location, proto.getName(), + funcType); } /// Emit a new function and add it to the MLIR module. @@ -166,7 +166,7 @@ class MLIRGenImpl { if (!entryBlock.empty()) returnOp = dyn_cast(entryBlock.back()); if (!returnOp) { - builder.create(loc(funcAST.getProto()->loc())); + ReturnOp::create(builder, loc(funcAST.getProto()->loc())); } else if (returnOp.hasOperand()) { // Otherwise, if this return operation has an operand then add a result to // the function. @@ -206,9 +206,9 @@ class MLIRGenImpl { // support '+' and '*'. switch (binop.getOp()) { case '+': - return builder.create(location, lhs, rhs); + return AddOp::create(builder, location, lhs, rhs); case '*': - return builder.create(location, lhs, rhs); + return MulOp::create(builder, location, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; @@ -239,8 +239,8 @@ class MLIRGenImpl { } // Otherwise, this return operation has zero operands. - builder.create(location, - expr ? ArrayRef(expr) : ArrayRef()); + ReturnOp::create(builder, location, + expr ? ArrayRef(expr) : ArrayRef()); return mlir::success(); } @@ -268,8 +268,7 @@ class MLIRGenImpl { // The attribute is a vector with a floating point value per element // (number) in the array, see `collectData()` below for more details. std::vector data; - data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, - std::multiplies())); + data.reserve(llvm::product_of(lit.getDims())); collectData(lit, data); // The type of this attribute is tensor of 64-bit floating-point with the @@ -284,7 +283,7 @@ class MLIRGenImpl { // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` // method. - return builder.create(loc(lit.loc()), type, dataAttribute); + return ConstantOp::create(builder, loc(lit.loc()), type, dataAttribute); } /// Recursive helper function to accumulate the data that compose an array @@ -329,13 +328,13 @@ class MLIRGenImpl { "does not accept multiple arguments"); return nullptr; } - return builder.create(location, operands[0]); + return TransposeOp::create(builder, location, operands[0]); } // Otherwise this is a call to a user-defined function. Calls to // user-defined functions are mapped to a custom call that takes the callee // name as an attribute. - return builder.create(location, callee, operands); + return GenericCallOp::create(builder, location, callee, operands); } /// Emit a print expression. It emits specific operations for two builtins: @@ -345,13 +344,13 @@ class MLIRGenImpl { if (!arg) return mlir::failure(); - builder.create(loc(call.loc()), arg); + PrintOp::create(builder, loc(call.loc()), arg); return mlir::success(); } /// Emit a constant for a single number (FIXME: semantic? broadcast?) mlir::Value mlirGen(NumberExprAST &num) { - return builder.create(loc(num.loc()), num.getValue()); + return ConstantOp::create(builder, loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. @@ -395,8 +394,8 @@ class MLIRGenImpl { // with specific shape, we emit a "reshape" operation. It will get // optimized out later as needed. if (!vardecl.getType().shape.empty()) { - value = builder.create(loc(vardecl.loc()), - getType(vardecl.getType()), value); + value = ReshapeOp::create(builder, loc(vardecl.loc()), + getType(vardecl.getType()), value); } // Register the value in the symbol table. diff --git a/mlir/example/Ch5/mlir/ShapeInferencePass.cpp b/mlir/example/Ch5/mlir/ShapeInferencePass.cpp index a9e995e..a552e1f 100644 --- a/mlir/example/Ch5/mlir/ShapeInferencePass.cpp +++ b/mlir/example/Ch5/mlir/ShapeInferencePass.cpp @@ -23,7 +23,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include @@ -55,6 +55,7 @@ namespace { struct ShapeInferencePass : public mlir::PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass) + StringRef getArgument() const override { return "toy-shape-inference"; } void runOnOperation() override { auto f = getOperation(); @@ -80,7 +81,7 @@ struct ShapeInferencePass opWorklist.erase(op); // Ask the operation to infer its output shapes. - LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + LDBG() << "Inferring shape for: " << *op; if (auto shapeOp = dyn_cast(op)) { shapeOp.inferShapes(); } else { diff --git a/mlir/example/Ch5/parser/AST.cpp b/mlir/example/Ch5/parser/AST.cpp index 2546f2a..8416424 100644 --- a/mlir/example/Ch5/parser/AST.cpp +++ b/mlir/example/Ch5/parser/AST.cpp @@ -120,7 +120,7 @@ void ASTDumper::dump(NumberExprAST *num) { /// [ [ 1, 2 ], [ 3, 4 ] ] /// We print out such array with the dimensions spelled out at every level: /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] -void printLitHelper(ExprAST *litOrNum) { +static void printLitHelper(ExprAST *litOrNum) { // Inside a literal expression we can have either a number or another literal if (auto *num = llvm::dyn_cast(litOrNum)) { llvm::errs() << num->getValue(); diff --git a/mlir/example/Ch5/toyc.cpp b/mlir/example/Ch5/toyc.cpp index 6a0c631..d62a1c0 100644 --- a/mlir/example/Ch5/toyc.cpp +++ b/mlir/example/Ch5/toyc.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Diagnostics.h" #include "toy/AST.h" #include "toy/Dialect.h" @@ -19,7 +20,7 @@ #include "toy/Parser.h" #include "toy/Passes.h" -#include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Affine/Transforms/Passes.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" @@ -70,7 +71,8 @@ static cl::opt emitAction( static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); /// Returns a Toy AST resulting from parsing the file or a nullptr on error. -std::unique_ptr parseInputFile(llvm::StringRef filename) { +static std::unique_ptr +parseInputFile(llvm::StringRef filename) { llvm::ErrorOr> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); if (std::error_code ec = fileOrErr.getError()) { @@ -83,8 +85,8 @@ std::unique_ptr parseInputFile(llvm::StringRef filename) { return parser.parseModule(); } -int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, - mlir::OwningOpRef &module) { +static int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, + mlir::OwningOpRef &module) { // Handle '.toy' input to the compiler. if (inputType != InputType::MLIR && !llvm::StringRef(inputFilename).ends_with(".mlir")) { @@ -113,7 +115,7 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context, return 0; } -int dumpMLIR() { +static int dumpMLIR() { mlir::DialectRegistry registry; mlir::func::registerAllExtensions(registry); @@ -170,7 +172,7 @@ int dumpMLIR() { return 0; } -int dumpAST() { +static int dumpAST() { if (inputType == InputType::MLIR) { llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; return 5; diff --git a/mlir/example/Ch6/include/toy/Lexer.h b/mlir/example/Ch6/include/toy/Lexer.h index 3c59cd9..22822cc 100644 --- a/mlir/example/Ch6/include/toy/Lexer.h +++ b/mlir/example/Ch6/include/toy/Lexer.h @@ -15,6 +15,7 @@ #include "llvm/ADT/StringRef.h" +#include #include #include diff --git a/mlir/example/Ch6/include/toy/Ops.td b/mlir/example/Ch6/include/toy/Ops.td index a52bebc..897b36d 100644 --- a/mlir/example/Ch6/include/toy/Ops.td +++ b/mlir/example/Ch6/include/toy/Ops.td @@ -72,7 +72,7 @@ def ConstantOp : Toy_Op<"constant", [Pure]> { // Add custom build methods for the constant operation. These method populates // the `state` that MLIR uses to create operations, i.e. these are used when - // using `builder.create(...)`. + // using `ConstantOp::create(builder, ...)`. let builders = [ // Build a constant with a given constant tensor value. OpBuilder<(ins "DenseElementsAttr":$value), [{ @@ -214,7 +214,12 @@ def GenericCallOp : Toy_Op<"generic_call", // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$inputs, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); // The generic call operation returns a single value of TensorType. let results = (outs F64Tensor); @@ -330,7 +335,7 @@ def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">, // Allow building a ReturnOp with no return operand. let builders = [ - OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]> + OpBuilder<(ins), [{ build($_builder, $_state, {}); }]> ]; // Provide extra utility definitions on the c++ operation class definition. diff --git a/mlir/example/Ch6/mlir/Dialect.cpp b/mlir/example/Ch6/mlir/Dialect.cpp index 72072f9..69fb69f 100644 --- a/mlir/example/Ch6/mlir/Dialect.cpp +++ b/mlir/example/Ch6/mlir/Dialect.cpp @@ -91,7 +91,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { Operation *materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const final { - return builder.create(conversionLoc, resultType, input); + return CastOp::create(builder, conversionLoc, resultType, input); } }; @@ -206,7 +206,8 @@ void ConstantOp::print(mlir::OpAsmPrinter &printer) { llvm::LogicalResult ConstantOp::verify() { // If the return type of the constant is not an unranked tensor, the shape // must match the shape of the attribute holding the data. - auto resultType = llvm::dyn_cast(getResult().getType()); + auto resultType = + llvm::dyn_cast(getResult().getType()); if (!resultType) return success(); @@ -333,7 +334,7 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() { /// Set the callee for the generic call operation, this is required by the call /// interface. void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { - (*this)->setAttr("callee", callee.get()); + (*this)->setAttr("callee", cast(callee)); } /// Get the argument operands to the called function, this is required by the @@ -395,7 +396,8 @@ llvm::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. - if (inputType == resultType || llvm::isa(inputType) || + if (inputType == resultType || + llvm::isa(inputType) || llvm::isa(resultType)) return mlir::success(); diff --git a/mlir/example/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/example/Ch6/mlir/LowerToAffineLoops.cpp index 7413214..2969d3a 100644 --- a/mlir/example/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/example/Ch6/mlir/LowerToAffineLoops.cpp @@ -44,7 +44,7 @@ using namespace mlir; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns +// ToyToAffine Conversion Patterns //===----------------------------------------------------------------------===// /// Convert the given RankedTensorType into the corresponding MemRefType. @@ -55,7 +55,7 @@ static MemRefType convertTensorToMemRef(RankedTensorType type) { /// Insert an allocation and deallocation for the given MemRefType. static Value insertAllocAndDealloc(MemRefType type, Location loc, PatternRewriter &rewriter) { - auto alloc = rewriter.create(loc, type); + auto alloc = memref::AllocOp::create(rewriter, loc, type); // Make sure to allocate at the beginning of the block. auto *parentBlock = alloc->getBlock(); @@ -63,21 +63,19 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc, // Make sure to deallocate this alloc at the end of the block. This is fine // as toy functions have no control flow. - auto dealloc = rewriter.create(loc, alloc); + auto dealloc = memref::DeallocOp::create(rewriter, loc, alloc); dealloc->moveBefore(&parentBlock->back()); return alloc; } /// This defines the function type used to process an iteration of a lowered -/// loop. It takes as input an OpBuilder, an range of memRefOperands -/// corresponding to the operands of the input operation, and the range of loop -/// induction variables for the iteration. It returns a value to store at the -/// current index of the iteration. -using LoopIterationFn = function_ref; - -static void lowerOpToLoops(Operation *op, ValueRange operands, - PatternRewriter &rewriter, +/// loop. It takes as input an OpBuilder and the range of loop induction +/// variables for the iteration. It returns a value to store at the current +/// index of the iteration. +using LoopIterationFn = + function_ref; + +static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter, LoopIterationFn processIteration) { auto tensorType = llvm::cast((*op->result_type_begin())); auto loc = op->getLoc(); @@ -95,12 +93,12 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, affine::buildAffineLoopNest( rewriter, loc, lowerBounds, tensorType.getShape(), steps, [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { - // Call the processing function with the rewriter, the memref operands, - // and the loop induction variables. This function will return the value - // to store at the current index. - Value valueToStore = processIteration(nestedBuilder, operands, ivs); - nestedBuilder.create(loc, valueToStore, alloc, - ivs); + // Call the processing function with the rewriter and the loop + // induction variables. This function will return the value to store at + // the current index. + Value valueToStore = processIteration(nestedBuilder, ivs); + affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc, + ivs); }); // Replace this operation with the generated alloc. @@ -109,38 +107,30 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, namespace { //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Binary operations +// ToyToAffine Conversion Patterns: Binary operations //===----------------------------------------------------------------------===// template -struct BinaryOpLowering : public ConversionPattern { - BinaryOpLowering(MLIRContext *ctx) - : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} +struct BinaryOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OpConversionPattern::OpAdaptor; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // BinaryOp. This allows for using the nice named accessors - // that are generated by the ODS. - typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); - - // Generate loads for the element of 'lhs' and 'rhs' at the - // inner loop. - auto loadedLhs = builder.create( - loc, binaryAdaptor.getLhs(), loopIvs); - auto loadedRhs = builder.create( - loc, binaryAdaptor.getRhs(), loopIvs); - - // Create the binary operation performed on the loaded - // values. - return builder.create(loc, loadedLhs, - loadedRhs); - }); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { + // Generate loads for the element of 'lhs' and 'rhs' at the + // inner loop. + auto loadedLhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs); + auto loadedRhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs); + + // Create the binary operation performed on the loaded + // values. + return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs); + }); return success(); } }; @@ -148,14 +138,15 @@ using AddOpLowering = BinaryOpLowering; using MulOpLowering = BinaryOpLowering; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Constant operations +// ToyToAffine Conversion Patterns: Constant operations //===----------------------------------------------------------------------===// -struct ConstantOpLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ConstantOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(toy::ConstantOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { DenseElementsAttr constantValue = op.getValue(); Location loc = op.getLoc(); @@ -174,11 +165,11 @@ struct ConstantOpLowering : public OpRewritePattern { if (!valueShape.empty()) { for (auto i : llvm::seq(0, *llvm::max_element(valueShape))) constantIndices.push_back( - rewriter.create(loc, i)); + arith::ConstantIndexOp::create(rewriter, loc, i)); } else { // This is the case of a tensor of rank 0. constantIndices.push_back( - rewriter.create(loc, 0)); + arith::ConstantIndexOp::create(rewriter, loc, 0)); } // The constant operation represents a multi-dimensional constant, so we @@ -191,9 +182,9 @@ struct ConstantOpLowering : public OpRewritePattern { // The last dimension is the base case of the recursion, at this point // we store the element at the given index. if (dimension == valueShape.size()) { - rewriter.create( - loc, rewriter.create(loc, *valueIt++), alloc, - llvm::ArrayRef(indices)); + affine::AffineStoreOp::create( + rewriter, loc, arith::ConstantOp::create(rewriter, loc, *valueIt++), + alloc, llvm::ArrayRef(indices)); return; } @@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Func operations +// ToyToAffine Conversion Patterns: Func operations //===----------------------------------------------------------------------===// struct FuncOpLowering : public OpConversionPattern { @@ -238,8 +229,8 @@ struct FuncOpLowering : public OpConversionPattern { } // Create a new non-toy function, with the same region. - auto func = rewriter.create(op.getLoc(), op.getName(), - op.getFunctionType()); + auto func = mlir::func::FuncOp::create(rewriter, op.getLoc(), op.getName(), + op.getFunctionType()); rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); rewriter.eraseOp(op); return success(); @@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Print operations +// ToyToAffine Conversion Patterns: Print operations //===----------------------------------------------------------------------===// struct PrintOpLowering : public OpConversionPattern { @@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Return operations +// ToyToAffine Conversion Patterns: Return operations //===----------------------------------------------------------------------===// -struct ReturnOpLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ReturnOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(toy::ReturnOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { // During this lowering, we expect that all function calls have been // inlined. if (op.hasOperand()) @@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Transpose operations +// ToyToAffine Conversion Patterns: Transpose operations //===----------------------------------------------------------------------===// -struct TransposeOpLowering : public ConversionPattern { - TransposeOpLowering(MLIRContext *ctx) - : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} +struct TransposeOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // TransposeOp. This allows for using the nice named - // accessors that are generated by the ODS. - toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); - Value input = transposeAdaptor.getInput(); - - // Transpose the elements by generating a load from the - // reverse indices. - SmallVector reverseIvs(llvm::reverse(loopIvs)); - return builder.create(loc, input, - reverseIvs); - }); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { + Value input = adaptor.getInput(); + + // Transpose the elements by generating a load from the + // reverse indices. + SmallVector reverseIvs(llvm::reverse(loopIvs)); + return affine::AffineLoadOp::create(builder, loc, input, reverseIvs); + }); return success(); } }; @@ -328,6 +312,7 @@ namespace { struct ToyToAffineLoweringPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ToyToAffineLoweringPass) + StringRef getArgument() const override { return "toy-to-affine"; } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert { public: - explicit PrintOpLowering(MLIRContext *context) - : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto *context = rewriter.getContext(); auto memRefType = llvm::cast((*op->operand_type_begin())); @@ -86,13 +85,13 @@ class PrintOpLowering : public ConversionPattern { // Create a loop for each of the dimensions within the shape. SmallVector loopIvs; for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { - auto lowerBound = rewriter.create(loc, 0); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); auto upperBound = - rewriter.create(loc, memRefShape[i]); - auto step = rewriter.create(loc, 1); + arith::ConstantIndexOp::create(rewriter, loc, memRefShape[i]); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); auto loop = - rewriter.create(loc, lowerBound, upperBound, step); - for (Operation &nested : *loop.getBody()) + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); + for (Operation &nested : make_early_inc_range(*loop.getBody())) rewriter.eraseOp(&nested); loopIvs.push_back(loop.getInductionVar()); @@ -101,19 +100,17 @@ class PrintOpLowering : public ConversionPattern { // Insert a newline after each of the inner dimensions of the shape. if (i != e - 1) - rewriter.create(loc, getPrintfType(context), printfRef, - newLineCst); - rewriter.create(loc); + LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef, + newLineCst); + scf::YieldOp::create(rewriter, loc); rewriter.setInsertionPointToStart(loop.getBody()); } // Generate a call to printf for the current element of the loop. - auto printOp = cast(op); auto elementLoad = - rewriter.create(loc, printOp.getInput(), loopIvs); - rewriter.create( - loc, getPrintfType(context), printfRef, - ArrayRef({formatSpecifierCst, elementLoad})); + memref::LoadOp::create(rewriter, loc, op.getInput(), loopIvs); + LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef, + ArrayRef({formatSpecifierCst, elementLoad})); // Notify the rewriter that this operation has been removed. rewriter.eraseOp(op); @@ -142,8 +139,8 @@ class PrintOpLowering : public ConversionPattern { // Insert the printf function into the body of the parent module. PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create(module.getLoc(), "printf", - getPrintfType(context)); + LLVM::LLVMFuncOp::create(rewriter, module.getLoc(), "printf", + getPrintfType(context)); return SymbolRefAttr::get(context, "printf"); } @@ -159,19 +156,19 @@ class PrintOpLowering : public ConversionPattern { builder.setInsertionPointToStart(module.getBody()); auto type = LLVM::LLVMArrayType::get( IntegerType::get(builder.getContext(), 8), value.size()); - global = builder.create(loc, type, /*isConstant=*/true, - LLVM::Linkage::Internal, name, - builder.getStringAttr(value), - /*alignment=*/0); + global = LLVM::GlobalOp::create(builder, loc, type, /*isConstant=*/true, + LLVM::Linkage::Internal, name, + builder.getStringAttr(value), + /*alignment=*/0); } // Get the pointer to the first character in the global string. - Value globalPtr = builder.create(loc, global); - Value cst0 = builder.create(loc, builder.getI64Type(), - builder.getIndexAttr(0)); - return builder.create( - loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(), - globalPtr, ArrayRef({cst0, cst0})); + Value globalPtr = LLVM::AddressOfOp::create(builder, loc, global); + Value cst0 = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(), + builder.getIndexAttr(0)); + return LLVM::GEPOp::create( + builder, loc, LLVM::LLVMPointerType::get(builder.getContext()), + global.getType(), globalPtr, ArrayRef({cst0, cst0})); } }; } // namespace @@ -184,6 +181,7 @@ namespace { struct ToyToLLVMLoweringPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ToyToLLVMLoweringPass) + StringRef getArgument() const override { return "toy-to-llvm"; } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); diff --git a/mlir/example/Ch6/mlir/MLIRGen.cpp b/mlir/example/Ch6/mlir/MLIRGen.cpp index b56e2f7..6b7ab40 100644 --- a/mlir/example/Ch6/mlir/MLIRGen.cpp +++ b/mlir/example/Ch6/mlir/MLIRGen.cpp @@ -120,9 +120,9 @@ class MLIRGenImpl { // Arguments type are uniformly unranked tensors. llvm::SmallVector argTypes(proto.getArgs().size(), getType(VarType{})); - auto funcType = builder.getFunctionType(argTypes, std::nullopt); - return builder.create(location, proto.getName(), - funcType); + auto funcType = builder.getFunctionType(argTypes, /*results=*/{}); + return mlir::toy::FuncOp::create(builder, location, proto.getName(), + funcType); } /// Emit a new function and add it to the MLIR module. @@ -166,7 +166,7 @@ class MLIRGenImpl { if (!entryBlock.empty()) returnOp = dyn_cast(entryBlock.back()); if (!returnOp) { - builder.create(loc(funcAST.getProto()->loc())); + ReturnOp::create(builder, loc(funcAST.getProto()->loc())); } else if (returnOp.hasOperand()) { // Otherwise, if this return operation has an operand then add a result to // the function. @@ -206,9 +206,9 @@ class MLIRGenImpl { // support '+' and '*'. switch (binop.getOp()) { case '+': - return builder.create(location, lhs, rhs); + return AddOp::create(builder, location, lhs, rhs); case '*': - return builder.create(location, lhs, rhs); + return MulOp::create(builder, location, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; @@ -239,8 +239,8 @@ class MLIRGenImpl { } // Otherwise, this return operation has zero operands. - builder.create(location, - expr ? ArrayRef(expr) : ArrayRef()); + ReturnOp::create(builder, location, + expr ? ArrayRef(expr) : ArrayRef()); return mlir::success(); } @@ -268,8 +268,7 @@ class MLIRGenImpl { // The attribute is a vector with a floating point value per element // (number) in the array, see `collectData()` below for more details. std::vector data; - data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, - std::multiplies())); + data.reserve(llvm::product_of(lit.getDims())); collectData(lit, data); // The type of this attribute is tensor of 64-bit floating-point with the @@ -284,7 +283,7 @@ class MLIRGenImpl { // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` // method. - return builder.create(loc(lit.loc()), type, dataAttribute); + return ConstantOp::create(builder, loc(lit.loc()), type, dataAttribute); } /// Recursive helper function to accumulate the data that compose an array @@ -329,13 +328,13 @@ class MLIRGenImpl { "does not accept multiple arguments"); return nullptr; } - return builder.create(location, operands[0]); + return TransposeOp::create(builder, location, operands[0]); } // Otherwise this is a call to a user-defined function. Calls to // user-defined functions are mapped to a custom call that takes the callee // name as an attribute. - return builder.create(location, callee, operands); + return GenericCallOp::create(builder, location, callee, operands); } /// Emit a print expression. It emits specific operations for two builtins: @@ -345,13 +344,13 @@ class MLIRGenImpl { if (!arg) return mlir::failure(); - builder.create(loc(call.loc()), arg); + PrintOp::create(builder, loc(call.loc()), arg); return mlir::success(); } /// Emit a constant for a single number (FIXME: semantic? broadcast?) mlir::Value mlirGen(NumberExprAST &num) { - return builder.create(loc(num.loc()), num.getValue()); + return ConstantOp::create(builder, loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. @@ -395,8 +394,8 @@ class MLIRGenImpl { // with specific shape, we emit a "reshape" operation. It will get // optimized out later as needed. if (!vardecl.getType().shape.empty()) { - value = builder.create(loc(vardecl.loc()), - getType(vardecl.getType()), value); + value = ReshapeOp::create(builder, loc(vardecl.loc()), + getType(vardecl.getType()), value); } // Register the value in the symbol table. diff --git a/mlir/example/Ch6/mlir/ShapeInferencePass.cpp b/mlir/example/Ch6/mlir/ShapeInferencePass.cpp index a9e995e..a552e1f 100644 --- a/mlir/example/Ch6/mlir/ShapeInferencePass.cpp +++ b/mlir/example/Ch6/mlir/ShapeInferencePass.cpp @@ -23,7 +23,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include @@ -55,6 +55,7 @@ namespace { struct ShapeInferencePass : public mlir::PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass) + StringRef getArgument() const override { return "toy-shape-inference"; } void runOnOperation() override { auto f = getOperation(); @@ -80,7 +81,7 @@ struct ShapeInferencePass opWorklist.erase(op); // Ask the operation to infer its output shapes. - LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + LDBG() << "Inferring shape for: " << *op; if (auto shapeOp = dyn_cast(op)) { shapeOp.inferShapes(); } else { diff --git a/mlir/example/Ch6/parser/AST.cpp b/mlir/example/Ch6/parser/AST.cpp index 2546f2a..8416424 100644 --- a/mlir/example/Ch6/parser/AST.cpp +++ b/mlir/example/Ch6/parser/AST.cpp @@ -120,7 +120,7 @@ void ASTDumper::dump(NumberExprAST *num) { /// [ [ 1, 2 ], [ 3, 4 ] ] /// We print out such array with the dimensions spelled out at every level: /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] -void printLitHelper(ExprAST *litOrNum) { +static void printLitHelper(ExprAST *litOrNum) { // Inside a literal expression we can have either a number or another literal if (auto *num = llvm::dyn_cast(litOrNum)) { llvm::errs() << num->getValue(); diff --git a/mlir/example/Ch6/toyc.cpp b/mlir/example/Ch6/toyc.cpp index c244b31..cfc10a7 100644 --- a/mlir/example/Ch6/toyc.cpp +++ b/mlir/example/Ch6/toyc.cpp @@ -11,7 +11,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "toy/AST.h" #include "toy/Dialect.h" #include "toy/Lexer.h" @@ -19,7 +21,7 @@ #include "toy/Parser.h" #include "toy/Passes.h" -#include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Affine/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/OptUtils.h" @@ -94,7 +96,8 @@ static cl::opt emitAction( static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); /// Returns a Toy AST resulting from parsing the file or a nullptr on error. -std::unique_ptr parseInputFile(llvm::StringRef filename) { +static std::unique_ptr +parseInputFile(llvm::StringRef filename) { llvm::ErrorOr> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); if (std::error_code ec = fileOrErr.getError()) { @@ -107,8 +110,8 @@ std::unique_ptr parseInputFile(llvm::StringRef filename) { return parser.parseModule(); } -int loadMLIR(mlir::MLIRContext &context, - mlir::OwningOpRef &module) { +static int loadMLIR(mlir::MLIRContext &context, + mlir::OwningOpRef &module) { // Handle '.toy' input to the compiler. if (inputType != InputType::MLIR && !llvm::StringRef(inputFilename).ends_with(".mlir")) { @@ -138,8 +141,8 @@ int loadMLIR(mlir::MLIRContext &context, return 0; } -int loadAndProcessMLIR(mlir::MLIRContext &context, - mlir::OwningOpRef &module) { +static int loadAndProcessMLIR(mlir::MLIRContext &context, + mlir::OwningOpRef &module) { if (int error = loadMLIR(context, module)) return error; @@ -194,7 +197,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context, return 0; } -int dumpAST() { +static int dumpAST() { if (inputType == InputType::MLIR) { llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; return 5; @@ -208,7 +211,7 @@ int dumpAST() { return 0; } -int dumpLLVMIR(mlir::ModuleOp module) { +static int dumpLLVMIR(mlir::ModuleOp module) { // Register the translation to LLVM IR with the MLIR context. mlir::registerBuiltinDialectTranslation(*module->getContext()); mlir::registerLLVMDialectTranslation(*module->getContext()); @@ -252,7 +255,7 @@ int dumpLLVMIR(mlir::ModuleOp module) { return 0; } -int runJit(mlir::ModuleOp module) { +static int runJit(mlir::ModuleOp module) { // Initialize LLVM targets. llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); @@ -299,6 +302,7 @@ int main(int argc, char **argv) { // If we aren't dumping the AST, then we are compiling with/to MLIR. mlir::DialectRegistry registry; mlir::func::registerAllExtensions(registry); + mlir::LLVM::registerInlinerInterface(registry); mlir::MLIRContext context(registry); // Load our Dialect in this MLIR Context. diff --git a/mlir/example/Ch7/include/toy/Lexer.h b/mlir/example/Ch7/include/toy/Lexer.h index a3fde91..f022c2f 100644 --- a/mlir/example/Ch7/include/toy/Lexer.h +++ b/mlir/example/Ch7/include/toy/Lexer.h @@ -15,6 +15,7 @@ #include "llvm/ADT/StringRef.h" +#include #include #include diff --git a/mlir/example/Ch7/include/toy/Ops.td b/mlir/example/Ch7/include/toy/Ops.td index cfd6859..9151396 100644 --- a/mlir/example/Ch7/include/toy/Ops.td +++ b/mlir/example/Ch7/include/toy/Ops.td @@ -93,7 +93,7 @@ def ConstantOp : Toy_Op<"constant", // Add custom build methods for the constant operation. These method populates // the `state` that MLIR uses to create operations, i.e. these are used when - // using `builder.create(...)`. + // using `ConstantOp::create(builder, ...)`. let builders = [ // Build a constant with a given constant tensor value. OpBuilder<(ins "DenseElementsAttr":$value), [{ @@ -237,7 +237,12 @@ def GenericCallOp : Toy_Op<"generic_call", // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$inputs, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); // The generic call operation returns a single value of TensorType or // StructType. @@ -250,7 +255,8 @@ def GenericCallOp : Toy_Op<"generic_call", // Add custom build methods for the generic call operation. let builders = [ - OpBuilder<(ins "StringRef":$callee, "ArrayRef":$arguments)> + OpBuilder<(ins "Type":$result_type, "StringRef":$callee, + "ArrayRef":$arguments)> ]; } @@ -354,7 +360,7 @@ def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">, // Allow building a ReturnOp with no return operand. let builders = [ - OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]> + OpBuilder<(ins), [{ build($_builder, $_state, {}); }]> ]; // Provide extra utility definitions on the c++ operation class definition. diff --git a/mlir/example/Ch7/mlir/Dialect.cpp b/mlir/example/Ch7/mlir/Dialect.cpp index 7e030ff..4d2f063 100644 --- a/mlir/example/Ch7/mlir/Dialect.cpp +++ b/mlir/example/Ch7/mlir/Dialect.cpp @@ -97,7 +97,7 @@ struct ToyInlinerInterface : public DialectInlinerInterface { Operation *materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const final { - return builder.create(conversionLoc, resultType, input); + return CastOp::create(builder, conversionLoc, resultType, input); } }; @@ -350,9 +350,9 @@ void FuncOp::print(mlir::OpAsmPrinter &p) { //===----------------------------------------------------------------------===// void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - StringRef callee, ArrayRef arguments) { - // Generic call always returns an unranked Tensor initially. - state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + mlir::Type resultType, StringRef callee, + ArrayRef arguments) { + state.addTypes(resultType); state.addOperands(arguments); state.addAttribute("callee", mlir::SymbolRefAttr::get(builder.getContext(), callee)); @@ -367,7 +367,7 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() { /// Set the callee for the generic call operation, this is required by the call /// interface. void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { - (*this)->setAttr("callee", callee.get()); + (*this)->setAttr("callee", cast(callee)); } /// Get the argument operands to the called function, this is required by the @@ -429,7 +429,8 @@ llvm::LogicalResult ReturnOp::verify() { auto resultType = results.front(); // Check that the result type of the function matches the operand type. - if (inputType == resultType || llvm::isa(inputType) || + if (inputType == resultType || + llvm::isa(inputType) || llvm::isa(resultType)) return mlir::success(); @@ -657,8 +658,8 @@ mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder, mlir::Type type, mlir::Location loc) { if (llvm::isa(type)) - return builder.create(loc, type, - llvm::cast(value)); - return builder.create(loc, type, - llvm::cast(value)); + return StructConstantOp::create(builder, loc, type, + llvm::cast(value)); + return ConstantOp::create(builder, loc, type, + llvm::cast(value)); } diff --git a/mlir/example/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/example/Ch7/mlir/LowerToAffineLoops.cpp index 7413214..cbe4236 100644 --- a/mlir/example/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/example/Ch7/mlir/LowerToAffineLoops.cpp @@ -44,7 +44,7 @@ using namespace mlir; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns +// ToyToAffine Conversion Patterns //===----------------------------------------------------------------------===// /// Convert the given RankedTensorType into the corresponding MemRefType. @@ -55,7 +55,7 @@ static MemRefType convertTensorToMemRef(RankedTensorType type) { /// Insert an allocation and deallocation for the given MemRefType. static Value insertAllocAndDealloc(MemRefType type, Location loc, PatternRewriter &rewriter) { - auto alloc = rewriter.create(loc, type); + auto alloc = memref::AllocOp::create(rewriter, loc, type); // Make sure to allocate at the beginning of the block. auto *parentBlock = alloc->getBlock(); @@ -63,21 +63,19 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc, // Make sure to deallocate this alloc at the end of the block. This is fine // as toy functions have no control flow. - auto dealloc = rewriter.create(loc, alloc); + auto dealloc = memref::DeallocOp::create(rewriter, loc, alloc); dealloc->moveBefore(&parentBlock->back()); return alloc; } /// This defines the function type used to process an iteration of a lowered -/// loop. It takes as input an OpBuilder, an range of memRefOperands -/// corresponding to the operands of the input operation, and the range of loop -/// induction variables for the iteration. It returns a value to store at the -/// current index of the iteration. -using LoopIterationFn = function_ref; - -static void lowerOpToLoops(Operation *op, ValueRange operands, - PatternRewriter &rewriter, +/// loop. It takes as input an OpBuilder and the range of loop induction +/// variables for the iteration. It returns a value to store at the current +/// index of the iteration. +using LoopIterationFn = + function_ref; + +static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter, LoopIterationFn processIteration) { auto tensorType = llvm::cast((*op->result_type_begin())); auto loc = op->getLoc(); @@ -95,12 +93,12 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, affine::buildAffineLoopNest( rewriter, loc, lowerBounds, tensorType.getShape(), steps, [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { - // Call the processing function with the rewriter, the memref operands, + // Call the processing function with the rewriter // and the loop induction variables. This function will return the value // to store at the current index. - Value valueToStore = processIteration(nestedBuilder, operands, ivs); - nestedBuilder.create(loc, valueToStore, alloc, - ivs); + Value valueToStore = processIteration(nestedBuilder, ivs); + affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc, + ivs); }); // Replace this operation with the generated alloc. @@ -109,38 +107,30 @@ static void lowerOpToLoops(Operation *op, ValueRange operands, namespace { //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Binary operations +// ToyToAffine Conversion Patterns: Binary operations //===----------------------------------------------------------------------===// template -struct BinaryOpLowering : public ConversionPattern { - BinaryOpLowering(MLIRContext *ctx) - : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} +struct BinaryOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OpConversionPattern::OpAdaptor; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(BinaryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // BinaryOp. This allows for using the nice named accessors - // that are generated by the ODS. - typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); - - // Generate loads for the element of 'lhs' and 'rhs' at the - // inner loop. - auto loadedLhs = builder.create( - loc, binaryAdaptor.getLhs(), loopIvs); - auto loadedRhs = builder.create( - loc, binaryAdaptor.getRhs(), loopIvs); - - // Create the binary operation performed on the loaded - // values. - return builder.create(loc, loadedLhs, - loadedRhs); - }); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { + // Generate loads for the element of 'lhs' and 'rhs' at the + // inner loop. + auto loadedLhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs); + auto loadedRhs = + affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs); + + // Create the binary operation performed on the loaded + // values. + return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs); + }); return success(); } }; @@ -148,14 +138,15 @@ using AddOpLowering = BinaryOpLowering; using MulOpLowering = BinaryOpLowering; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Constant operations +// ToyToAffine Conversion Patterns: Constant operations //===----------------------------------------------------------------------===// -struct ConstantOpLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ConstantOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(toy::ConstantOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { DenseElementsAttr constantValue = op.getValue(); Location loc = op.getLoc(); @@ -174,11 +165,11 @@ struct ConstantOpLowering : public OpRewritePattern { if (!valueShape.empty()) { for (auto i : llvm::seq(0, *llvm::max_element(valueShape))) constantIndices.push_back( - rewriter.create(loc, i)); + arith::ConstantIndexOp::create(rewriter, loc, i)); } else { // This is the case of a tensor of rank 0. constantIndices.push_back( - rewriter.create(loc, 0)); + arith::ConstantIndexOp::create(rewriter, loc, 0)); } // The constant operation represents a multi-dimensional constant, so we @@ -191,9 +182,9 @@ struct ConstantOpLowering : public OpRewritePattern { // The last dimension is the base case of the recursion, at this point // we store the element at the given index. if (dimension == valueShape.size()) { - rewriter.create( - loc, rewriter.create(loc, *valueIt++), alloc, - llvm::ArrayRef(indices)); + affine::AffineStoreOp::create( + rewriter, loc, arith::ConstantOp::create(rewriter, loc, *valueIt++), + alloc, llvm::ArrayRef(indices)); return; } @@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Func operations +// ToyToAffine Conversion Patterns: Func operations //===----------------------------------------------------------------------===// struct FuncOpLowering : public OpConversionPattern { @@ -238,8 +229,8 @@ struct FuncOpLowering : public OpConversionPattern { } // Create a new non-toy function, with the same region. - auto func = rewriter.create(op.getLoc(), op.getName(), - op.getFunctionType()); + auto func = mlir::func::FuncOp::create(rewriter, op.getLoc(), op.getName(), + op.getFunctionType()); rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); rewriter.eraseOp(op); return success(); @@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Print operations +// ToyToAffine Conversion Patterns: Print operations //===----------------------------------------------------------------------===// struct PrintOpLowering : public OpConversionPattern { @@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Return operations +// ToyToAffine Conversion Patterns: Return operations //===----------------------------------------------------------------------===// -struct ReturnOpLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ReturnOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(toy::ReturnOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { // During this lowering, we expect that all function calls have been // inlined. if (op.hasOperand()) @@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern { }; //===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Transpose operations +// ToyToAffine Conversion Patterns: Transpose operations //===----------------------------------------------------------------------===// -struct TransposeOpLowering : public ConversionPattern { - TransposeOpLowering(MLIRContext *ctx) - : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} +struct TransposeOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // TransposeOp. This allows for using the nice named - // accessors that are generated by the ODS. - toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); - Value input = transposeAdaptor.getInput(); - - // Transpose the elements by generating a load from the - // reverse indices. - SmallVector reverseIvs(llvm::reverse(loopIvs)); - return builder.create(loc, input, - reverseIvs); - }); + lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) { + Value input = adaptor.getInput(); + + // Transpose the elements by generating a load from the + // reverse indices. + SmallVector reverseIvs(llvm::reverse(loopIvs)); + return affine::AffineLoadOp::create(builder, loc, input, reverseIvs); + }); return success(); } }; @@ -328,6 +312,7 @@ namespace { struct ToyToAffineLoweringPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ToyToAffineLoweringPass) + StringRef getArgument() const override { return "toy-to-affine"; } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert { public: - explicit PrintOpLowering(MLIRContext *context) - : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto *context = rewriter.getContext(); auto memRefType = llvm::cast((*op->operand_type_begin())); @@ -86,13 +85,13 @@ class PrintOpLowering : public ConversionPattern { // Create a loop for each of the dimensions within the shape. SmallVector loopIvs; for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { - auto lowerBound = rewriter.create(loc, 0); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); auto upperBound = - rewriter.create(loc, memRefShape[i]); - auto step = rewriter.create(loc, 1); + arith::ConstantIndexOp::create(rewriter, loc, memRefShape[i]); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); auto loop = - rewriter.create(loc, lowerBound, upperBound, step); - for (Operation &nested : *loop.getBody()) + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); + for (Operation &nested : make_early_inc_range(*loop.getBody())) rewriter.eraseOp(&nested); loopIvs.push_back(loop.getInductionVar()); @@ -101,19 +100,17 @@ class PrintOpLowering : public ConversionPattern { // Insert a newline after each of the inner dimensions of the shape. if (i != e - 1) - rewriter.create(loc, getPrintfType(context), printfRef, - newLineCst); - rewriter.create(loc); + LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef, + newLineCst); + scf::YieldOp::create(rewriter, loc); rewriter.setInsertionPointToStart(loop.getBody()); } // Generate a call to printf for the current element of the loop. - auto printOp = cast(op); auto elementLoad = - rewriter.create(loc, printOp.getInput(), loopIvs); - rewriter.create( - loc, getPrintfType(context), printfRef, - ArrayRef({formatSpecifierCst, elementLoad})); + memref::LoadOp::create(rewriter, loc, op.getInput(), loopIvs); + LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef, + ArrayRef({formatSpecifierCst, elementLoad})); // Notify the rewriter that this operation has been removed. rewriter.eraseOp(op); @@ -142,8 +139,8 @@ class PrintOpLowering : public ConversionPattern { // Insert the printf function into the body of the parent module. PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create(module.getLoc(), "printf", - getPrintfType(context)); + LLVM::LLVMFuncOp::create(rewriter, module.getLoc(), "printf", + getPrintfType(context)); return SymbolRefAttr::get(context, "printf"); } @@ -159,19 +156,19 @@ class PrintOpLowering : public ConversionPattern { builder.setInsertionPointToStart(module.getBody()); auto type = LLVM::LLVMArrayType::get( IntegerType::get(builder.getContext(), 8), value.size()); - global = builder.create(loc, type, /*isConstant=*/true, - LLVM::Linkage::Internal, name, - builder.getStringAttr(value), - /*alignment=*/0); + global = LLVM::GlobalOp::create(builder, loc, type, /*isConstant=*/true, + LLVM::Linkage::Internal, name, + builder.getStringAttr(value), + /*alignment=*/0); } // Get the pointer to the first character in the global string. - Value globalPtr = builder.create(loc, global); - Value cst0 = builder.create(loc, builder.getI64Type(), - builder.getIndexAttr(0)); - return builder.create( - loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(), - globalPtr, ArrayRef({cst0, cst0})); + Value globalPtr = LLVM::AddressOfOp::create(builder, loc, global); + Value cst0 = LLVM::ConstantOp::create(builder, loc, builder.getI64Type(), + builder.getIndexAttr(0)); + return LLVM::GEPOp::create( + builder, loc, LLVM::LLVMPointerType::get(builder.getContext()), + global.getType(), globalPtr, ArrayRef({cst0, cst0})); } }; } // namespace @@ -184,6 +181,7 @@ namespace { struct ToyToLLVMLoweringPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ToyToLLVMLoweringPass) + StringRef getArgument() const override { return "toy-to-llvm"; } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -220,6 +218,7 @@ void ToyToLLVMLoweringPass::runOnOperation() { mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns); cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); + cf::populateAssertToLLVMConversionPattern(typeConverter, patterns); populateFuncToLLVMConversionPatterns(typeConverter, patterns); // The only remaining operation to lower from the `toy` dialect, is the diff --git a/mlir/example/Ch7/mlir/MLIRGen.cpp b/mlir/example/Ch7/mlir/MLIRGen.cpp index 090e5ff..7313324 100644 --- a/mlir/example/Ch7/mlir/MLIRGen.cpp +++ b/mlir/example/Ch7/mlir/MLIRGen.cpp @@ -182,9 +182,9 @@ class MLIRGenImpl { return nullptr; argTypes.push_back(type); } - auto funcType = builder.getFunctionType(argTypes, std::nullopt); - return builder.create(location, proto.getName(), - funcType); + auto funcType = builder.getFunctionType(argTypes, /*results=*/{}); + return mlir::toy::FuncOp::create(builder, location, proto.getName(), + funcType); } /// Emit a new function and add it to the MLIR module. @@ -227,7 +227,7 @@ class MLIRGenImpl { if (!entryBlock.empty()) returnOp = dyn_cast(entryBlock.back()); if (!returnOp) { - builder.create(loc(funcAST.getProto()->loc())); + ReturnOp::create(builder, loc(funcAST.getProto()->loc())); } else if (returnOp.hasOperand()) { // Otherwise, if this return operation has an operand then add a result to // the function. @@ -333,7 +333,7 @@ class MLIRGenImpl { emitError(location, "invalid access into struct expression"); return nullptr; } - return builder.create(location, lhs, *accessIndex); + return StructAccessOp::create(builder, location, lhs, *accessIndex); } // Otherwise, this is a normal binary op. @@ -345,9 +345,9 @@ class MLIRGenImpl { // support '+' and '*'. switch (binop.getOp()) { case '+': - return builder.create(location, lhs, rhs); + return AddOp::create(builder, location, lhs, rhs); case '*': - return builder.create(location, lhs, rhs); + return MulOp::create(builder, location, lhs, rhs); } emitError(location, "invalid binary operator '") << binop.getOp() << "'"; @@ -378,8 +378,8 @@ class MLIRGenImpl { } // Otherwise, this return operation has zero operands. - builder.create(location, - expr ? ArrayRef(expr) : ArrayRef()); + ReturnOp::create(builder, location, + expr ? ArrayRef(expr) : ArrayRef()); return mlir::success(); } @@ -405,8 +405,7 @@ class MLIRGenImpl { // The attribute is a vector with a floating point value per element // (number) in the array, see `collectData()` below for more details. std::vector data; - data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, - std::multiplies())); + data.reserve(llvm::product_of(lit.getDims())); collectData(lit, data); // The type of this attribute is tensor of 64-bit floating-point with the @@ -441,10 +440,10 @@ class MLIRGenImpl { for (auto &var : lit.getValues()) { if (auto *number = llvm::dyn_cast(var.get())) { attrElements.push_back(getConstantAttr(*number)); - typeElements.push_back(getType(std::nullopt)); + typeElements.push_back(getType(/*shape=*/{})); } else if (auto *lit = llvm::dyn_cast(var.get())) { attrElements.push_back(getConstantAttr(*lit)); - typeElements.push_back(getType(std::nullopt)); + typeElements.push_back(getType(/*shape=*/{})); } else { auto *structLit = llvm::cast(var.get()); auto attrTypePair = getConstantAttr(*structLit); @@ -464,7 +463,7 @@ class MLIRGenImpl { // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` // method. - return builder.create(loc(lit.loc()), type, dataAttribute); + return ConstantOp::create(builder, loc(lit.loc()), type, dataAttribute); } /// Emit a struct literal. It will be emitted as an array of @@ -477,7 +476,8 @@ class MLIRGenImpl { // Build the MLIR op `toy.struct_constant`. This invokes the // `StructConstantOp::build` method. - return builder.create(loc(lit.loc()), dataType, dataAttr); + return StructConstantOp::create(builder, loc(lit.loc()), dataType, + dataAttr); } /// Recursive helper function to accumulate the data that compose an array @@ -522,7 +522,7 @@ class MLIRGenImpl { "does not accept multiple arguments"); return nullptr; } - return builder.create(location, operands[0]); + return TransposeOp::create(builder, location, operands[0]); } // Otherwise this is a call to a user-defined function. Calls to @@ -534,9 +534,9 @@ class MLIRGenImpl { return nullptr; } mlir::toy::FuncOp calledFunc = calledFuncIt->second; - return builder.create( - location, calledFunc.getFunctionType().getResult(0), - mlir::SymbolRefAttr::get(builder.getContext(), callee), operands); + return GenericCallOp::create(builder, location, + calledFunc.getFunctionType().getResult(0), + callee, operands); } /// Emit a print expression. It emits specific operations for two builtins: @@ -546,13 +546,13 @@ class MLIRGenImpl { if (!arg) return mlir::failure(); - builder.create(loc(call.loc()), arg); + PrintOp::create(builder, loc(call.loc()), arg); return mlir::success(); } /// Emit a constant for a single number (FIXME: semantic? broadcast?) mlir::Value mlirGen(NumberExprAST &num) { - return builder.create(loc(num.loc()), num.getValue()); + return ConstantOp::create(builder, loc(num.loc()), num.getValue()); } /// Dispatch codegen for the right expression subclass using RTTI. @@ -614,8 +614,8 @@ class MLIRGenImpl { // declared with specific shape, we emit a "reshape" operation. It will // get optimized out later as needed. } else if (!varType.shape.empty()) { - value = builder.create(loc(vardecl.loc()), - getType(varType.shape), value); + value = ReshapeOp::create(builder, loc(vardecl.loc()), + getType(varType.shape), value); } // Register the value in the symbol table. diff --git a/mlir/example/Ch7/mlir/ShapeInferencePass.cpp b/mlir/example/Ch7/mlir/ShapeInferencePass.cpp index a9e995e..a552e1f 100644 --- a/mlir/example/Ch7/mlir/ShapeInferencePass.cpp +++ b/mlir/example/Ch7/mlir/ShapeInferencePass.cpp @@ -23,7 +23,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include @@ -55,6 +55,7 @@ namespace { struct ShapeInferencePass : public mlir::PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass) + StringRef getArgument() const override { return "toy-shape-inference"; } void runOnOperation() override { auto f = getOperation(); @@ -80,7 +81,7 @@ struct ShapeInferencePass opWorklist.erase(op); // Ask the operation to infer its output shapes. - LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); + LDBG() << "Inferring shape for: " << *op; if (auto shapeOp = dyn_cast(op)) { shapeOp.inferShapes(); } else { diff --git a/mlir/example/Ch7/parser/AST.cpp b/mlir/example/Ch7/parser/AST.cpp index e38a743..aa2c784 100644 --- a/mlir/example/Ch7/parser/AST.cpp +++ b/mlir/example/Ch7/parser/AST.cpp @@ -123,7 +123,7 @@ void ASTDumper::dump(NumberExprAST *num) { /// [ [ 1, 2 ], [ 3, 4 ] ] /// We print out such array with the dimensions spelled out at every level: /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] -void printLitHelper(ExprAST *litOrNum) { +static void printLitHelper(ExprAST *litOrNum) { // Inside a literal expression we can have either a number or another literal if (auto *num = llvm::dyn_cast(litOrNum)) { llvm::errs() << num->getValue(); diff --git a/mlir/example/Ch7/toyc.cpp b/mlir/example/Ch7/toyc.cpp index fea5679..ffd94bc 100644 --- a/mlir/example/Ch7/toyc.cpp +++ b/mlir/example/Ch7/toyc.cpp @@ -11,7 +11,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "toy/AST.h" #include "toy/Dialect.h" #include "toy/Lexer.h" @@ -19,7 +21,7 @@ #include "toy/Parser.h" #include "toy/Passes.h" -#include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Affine/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/OptUtils.h" @@ -94,7 +96,8 @@ static cl::opt emitAction( static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); /// Returns a Toy AST resulting from parsing the file or a nullptr on error. -std::unique_ptr parseInputFile(llvm::StringRef filename) { +static std::unique_ptr +parseInputFile(llvm::StringRef filename) { llvm::ErrorOr> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); if (std::error_code ec = fileOrErr.getError()) { @@ -107,8 +110,8 @@ std::unique_ptr parseInputFile(llvm::StringRef filename) { return parser.parseModule(); } -int loadMLIR(mlir::MLIRContext &context, - mlir::OwningOpRef &module) { +static int loadMLIR(mlir::MLIRContext &context, + mlir::OwningOpRef &module) { // Handle '.toy' input to the compiler. if (inputType != InputType::MLIR && !llvm::StringRef(inputFilename).ends_with(".mlir")) { @@ -138,8 +141,8 @@ int loadMLIR(mlir::MLIRContext &context, return 0; } -int loadAndProcessMLIR(mlir::MLIRContext &context, - mlir::OwningOpRef &module) { +static int loadAndProcessMLIR(mlir::MLIRContext &context, + mlir::OwningOpRef &module) { if (int error = loadMLIR(context, module)) return error; @@ -195,7 +198,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context, return 0; } -int dumpAST() { +static int dumpAST() { if (inputType == InputType::MLIR) { llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; return 5; @@ -209,7 +212,7 @@ int dumpAST() { return 0; } -int dumpLLVMIR(mlir::ModuleOp module) { +static int dumpLLVMIR(mlir::ModuleOp module) { // Register the translation to LLVM IR with the MLIR context. mlir::registerBuiltinDialectTranslation(*module->getContext()); mlir::registerLLVMDialectTranslation(*module->getContext()); @@ -253,7 +256,7 @@ int dumpLLVMIR(mlir::ModuleOp module) { return 0; } -int runJit(mlir::ModuleOp module) { +static int runJit(mlir::ModuleOp module) { // Initialize LLVM targets. llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); @@ -300,6 +303,7 @@ int main(int argc, char **argv) { // If we aren't dumping the AST, then we are compiling with/to MLIR. mlir::DialectRegistry registry; mlir::func::registerAllExtensions(registry); + mlir::LLVM::registerInlinerInterface(registry); mlir::MLIRContext context(registry); // Load our Dialect in this MLIR Context. diff --git a/mlir/example/Ch8/CMakeLists.txt b/mlir/example/Ch8/CMakeLists.txt deleted file mode 100644 index 0bfd338..0000000 --- a/mlir/example/Ch8/CMakeLists.txt +++ /dev/null @@ -1,48 +0,0 @@ -# For a better template to copy, see examples/standalone -include_directories(include) -add_subdirectory(include) - -set(LLVM_LINK_COMPONENTS Core Support nativecodegen OrcJIT) - -set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td) -mlir_tablegen(ToyCombine.inc -gen-rewriters) -add_public_tablegen_target(ToyCh8CombineIncGen) - -add_executable( - mlir-example-ch8 - toyc.cpp - parser/AST.cpp - mlir/MLIRGen.cpp - mlir/Dialect.cpp - mlir/LowerToAffineLoops.cpp - mlir/LowerToLLVM.cpp - mlir/ShapeInferencePass.cpp - mlir/ToyCombine.cpp) - -add_dependencies(mlir-example-ch8 ToyCh8ShapeInferenceInterfaceIncGen - ToyCh8OpsIncGen ToyCh8CombineIncGen) - -include_directories(${CMAKE_CURRENT_BINARY_DIR}) -include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) -target_link_libraries( - mlir-example-ch8 - PRIVATE ${dialect_libs} - ${conversion_libs} - ${extension_libs} - MLIRAnalysis - MLIRBuiltinToLLVMIRTranslation - MLIRCallInterfaces - MLIRCastInterfaces - MLIRExecutionEngine - MLIRIR - MLIRLLVMCommonConversion - MLIRLLVMToLLVMIRTranslation - MLIRMemRefDialect - MLIRParser - MLIRPass - MLIRSideEffectInterfaces - MLIRTargetLLVMIRExport - MLIRTransforms) diff --git a/mlir/example/Ch8/include/CMakeLists.txt b/mlir/example/Ch8/include/CMakeLists.txt deleted file mode 100644 index 37c89d0..0000000 --- a/mlir/example/Ch8/include/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(toy) diff --git a/mlir/example/Ch8/include/toy/AST.h b/mlir/example/Ch8/include/toy/AST.h deleted file mode 100644 index 4827865..0000000 --- a/mlir/example/Ch8/include/toy/AST.h +++ /dev/null @@ -1,313 +0,0 @@ -//===- AST.h - Node definition for the Toy AST ----------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements the AST for the Toy language. It is optimized for -// simplicity, not efficiency. The AST forms a tree structure where each node -// references its children using std::unique_ptr<>. -// -//===----------------------------------------------------------------------===// - -#ifndef TOY_AST_H -#define TOY_AST_H - -#include "toy/Lexer.h" - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Casting.h" -#include -#include -#include - -namespace toy { - -/// A variable type with either name or shape information. -struct VarType { - std::string name; - std::vector shape; -}; - -/// Base class for all expression nodes. -class ExprAST { -public: - enum ExprASTKind { - Expr_VarDecl, - Expr_Return, - Expr_Num, - Expr_Literal, - Expr_StructLiteral, - Expr_Var, - Expr_BinOp, - Expr_Call, - Expr_Print, - }; - - ExprAST(ExprASTKind kind, Location location) - : kind(kind), location(std::move(location)) {} - virtual ~ExprAST() = default; - - ExprASTKind getKind() const { return kind; } - - const Location &loc() { return location; } - -private: - const ExprASTKind kind; - Location location; -}; - -/// A block-list of expressions. -using ExprASTList = std::vector>; - -/// Expression class for numeric literals like "1.0". -class NumberExprAST : public ExprAST { - double val; - -public: - NumberExprAST(Location loc, double val) - : ExprAST(Expr_Num, std::move(loc)), val(val) {} - - double getValue() { return val; } - - /// LLVM style RTTI - static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; } -}; - -/// Expression class for a literal value. -class LiteralExprAST : public ExprAST { - std::vector> values; - std::vector dims; - -public: - LiteralExprAST(Location loc, std::vector> values, - std::vector dims) - : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)), - dims(std::move(dims)) {} - - llvm::ArrayRef> getValues() { return values; } - llvm::ArrayRef getDims() { return dims; } - - /// LLVM style RTTI - static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; } -}; - -/// Expression class for a literal struct value. -class StructLiteralExprAST : public ExprAST { - std::vector> values; - -public: - StructLiteralExprAST(Location loc, - std::vector> values) - : ExprAST(Expr_StructLiteral, std::move(loc)), values(std::move(values)) { - } - - llvm::ArrayRef> getValues() { return values; } - - /// LLVM style RTTI - static bool classof(const ExprAST *c) { - return c->getKind() == Expr_StructLiteral; - } -}; - -/// Expression class for referencing a variable, like "a". -class VariableExprAST : public ExprAST { - std::string name; - -public: - VariableExprAST(Location loc, llvm::StringRef name) - : ExprAST(Expr_Var, std::move(loc)), name(name) {} - - llvm::StringRef getName() { return name; } - - /// LLVM style RTTI - static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; } -}; - -/// Expression class for defining a variable. -class VarDeclExprAST : public ExprAST { - std::string name; - VarType type; - std::unique_ptr initVal; - -public: - VarDeclExprAST(Location loc, llvm::StringRef name, VarType type, - std::unique_ptr initVal = nullptr) - : ExprAST(Expr_VarDecl, std::move(loc)), name(name), - type(std::move(type)), initVal(std::move(initVal)) {} - - llvm::StringRef getName() { return name; } - ExprAST *getInitVal() { return initVal.get(); } - const VarType &getType() { return type; } - - /// LLVM style RTTI - static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; } -}; - -/// Expression class for a return operator. -class ReturnExprAST : public ExprAST { - std::optional> expr; - -public: - ReturnExprAST(Location loc, std::optional> expr) - : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {} - - std::optional getExpr() { - if (expr.has_value()) - return expr->get(); - return std::nullopt; - } - - /// LLVM style RTTI - static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; } -}; - -/// Expression class for a binary operator. -class BinaryExprAST : public ExprAST { - char op; - std::unique_ptr lhs, rhs; - -public: - char getOp() { return op; } - ExprAST *getLHS() { return lhs.get(); } - ExprAST *getRHS() { return rhs.get(); } - - BinaryExprAST(Location loc, char op, std::unique_ptr lhs, - std::unique_ptr rhs) - : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)), - rhs(std::move(rhs)) {} - - /// LLVM style RTTI - static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; } -}; - -/// Expression class for function calls. -class CallExprAST : public ExprAST { - std::string callee; - std::vector> args; - -public: - CallExprAST(Location loc, const std::string &callee, - std::vector> args) - : ExprAST(Expr_Call, std::move(loc)), callee(callee), - args(std::move(args)) {} - - llvm::StringRef getCallee() { return callee; } - llvm::ArrayRef> getArgs() { return args; } - - /// LLVM style RTTI - static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; } -}; - -/// Expression class for builtin print calls. -class PrintExprAST : public ExprAST { - std::unique_ptr arg; - -public: - PrintExprAST(Location loc, std::unique_ptr arg) - : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {} - - ExprAST *getArg() { return arg.get(); } - - /// LLVM style RTTI - static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; } -}; - -/// This class represents the "prototype" for a function, which captures its -/// name, and its argument names (thus implicitly the number of arguments the -/// function takes). -class PrototypeAST { - Location location; - std::string name; - std::vector> args; - -public: - PrototypeAST(Location location, const std::string &name, - std::vector> args) - : location(std::move(location)), name(name), args(std::move(args)) {} - - const Location &loc() { return location; } - llvm::StringRef getName() const { return name; } - llvm::ArrayRef> getArgs() { return args; } -}; - -/// This class represents a top level record in a module. -class RecordAST { -public: - enum RecordASTKind { - Record_Function, - Record_Struct, - }; - - RecordAST(RecordASTKind kind) : kind(kind) {} - virtual ~RecordAST() = default; - - RecordASTKind getKind() const { return kind; } - -private: - const RecordASTKind kind; -}; - -/// This class represents a function definition itself. -class FunctionAST : public RecordAST { - std::unique_ptr proto; - std::unique_ptr body; - -public: - FunctionAST(std::unique_ptr proto, - std::unique_ptr body) - : RecordAST(Record_Function), proto(std::move(proto)), - body(std::move(body)) {} - PrototypeAST *getProto() { return proto.get(); } - ExprASTList *getBody() { return body.get(); } - - /// LLVM style RTTI - static bool classof(const RecordAST *r) { - return r->getKind() == Record_Function; - } -}; - -/// This class represents a struct definition. -class StructAST : public RecordAST { - Location location; - std::string name; - std::vector> variables; - -public: - StructAST(Location location, const std::string &name, - std::vector> variables) - : RecordAST(Record_Struct), location(std::move(location)), name(name), - variables(std::move(variables)) {} - - const Location &loc() { return location; } - llvm::StringRef getName() const { return name; } - llvm::ArrayRef> getVariables() { - return variables; - } - - /// LLVM style RTTI - static bool classof(const RecordAST *r) { - return r->getKind() == Record_Struct; - } -}; - -/// This class represents a list of functions to be processed together -class ModuleAST { - std::vector> records; - -public: - ModuleAST(std::vector> records) - : records(std::move(records)) {} - - auto begin() { return records.begin(); } - auto end() { return records.end(); } -}; - -void dump(ModuleAST &); - -} // namespace toy - -#endif // TOY_AST_H diff --git a/mlir/example/Ch8/include/toy/CMakeLists.txt b/mlir/example/Ch8/include/toy/CMakeLists.txt deleted file mode 100644 index ecbf653..0000000 --- a/mlir/example/Ch8/include/toy/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -# Most dialects should use add_mlir_dialect(). See examples/standalone. -set(LLVM_TARGET_DEFINITIONS Ops.td) -mlir_tablegen(Ops.h.inc -gen-op-decls) -mlir_tablegen(Ops.cpp.inc -gen-op-defs) -mlir_tablegen(Dialect.h.inc -gen-dialect-decls) -mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) -add_public_tablegen_target(ToyCh8OpsIncGen) - -# Most dialects should use add_mlir_interfaces(). -set(LLVM_TARGET_DEFINITIONS ShapeInferenceInterface.td) -mlir_tablegen(ShapeInferenceOpInterfaces.h.inc -gen-op-interface-decls) -mlir_tablegen(ShapeInferenceOpInterfaces.cpp.inc -gen-op-interface-defs) -add_public_tablegen_target(ToyCh8ShapeInferenceInterfaceIncGen) diff --git a/mlir/example/Ch8/include/toy/Dialect.h b/mlir/example/Ch8/include/toy/Dialect.h deleted file mode 100644 index 64094c3..0000000 --- a/mlir/example/Ch8/include/toy/Dialect.h +++ /dev/null @@ -1,82 +0,0 @@ -//===- Dialect.h - Dialect definition for the Toy IR ----------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements the IR Dialect for the Toy language. -// See docs/Tutorials/Toy/Ch-2.md for more information. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TUTORIAL_TOY_DIALECT_H_ -#define MLIR_TUTORIAL_TOY_DIALECT_H_ - -#include "mlir/Bytecode/BytecodeOpInterface.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/Interfaces/CallInterfaces.h" -#include "mlir/Interfaces/CastInterfaces.h" -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "toy/ShapeInferenceInterface.h" - -namespace mlir { -namespace toy { -namespace detail { -struct StructTypeStorage; -} // namespace detail -} // namespace toy -} // namespace mlir - -/// Include the auto-generated header file containing the declaration of the toy -/// dialect. -#include "toy/Dialect.h.inc" - -//===----------------------------------------------------------------------===// -// Toy Operations -//===----------------------------------------------------------------------===// - -/// Include the auto-generated header file containing the declarations of the -/// toy operations. -#define GET_OP_CLASSES -#include "toy/Ops.h.inc" - -namespace mlir { -namespace toy { - -//===----------------------------------------------------------------------===// -// Toy Types -//===----------------------------------------------------------------------===// - -/// This class defines the Toy struct type. It represents a collection of -/// element types. All derived types in MLIR must inherit from the CRTP class -/// 'Type::TypeBase'. It takes as template parameters the concrete type -/// (StructType), the base class to use (Type), and the storage class -/// (StructTypeStorage). -class StructType : public mlir::Type::TypeBase { -public: - /// Inherit some necessary constructors from 'TypeBase'. - using Base::Base; - - /// Create an instance of a `StructType` with the given element types. There - /// *must* be atleast one element type. - static StructType get(llvm::ArrayRef elementTypes); - - /// Returns the element types of this struct type. - llvm::ArrayRef getElementTypes(); - - /// Returns the number of element type held by this struct. - size_t getNumElementTypes() { return getElementTypes().size(); } - - /// The name of this struct type. - static constexpr StringLiteral name = "toy.struct"; -}; -} // namespace toy -} // namespace mlir - -#endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/mlir/example/Ch8/include/toy/Lexer.h b/mlir/example/Ch8/include/toy/Lexer.h deleted file mode 100644 index a3fde91..0000000 --- a/mlir/example/Ch8/include/toy/Lexer.h +++ /dev/null @@ -1,235 +0,0 @@ -//===- Lexer.h - Lexer for the Toy language -------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a simple Lexer for the Toy language. -// -//===----------------------------------------------------------------------===// - -#ifndef TOY_LEXER_H -#define TOY_LEXER_H - -#include "llvm/ADT/StringRef.h" - -#include -#include - -namespace toy { - -/// Structure definition a location in a file. -struct Location { - std::shared_ptr file; ///< filename. - int line; ///< line number. - int col; ///< column number. -}; - -// List of Token returned by the lexer. -enum Token : int { - tok_semicolon = ';', - tok_parenthese_open = '(', - tok_parenthese_close = ')', - tok_bracket_open = '{', - tok_bracket_close = '}', - tok_sbracket_open = '[', - tok_sbracket_close = ']', - - tok_eof = -1, - - // commands - tok_return = -2, - tok_var = -3, - tok_def = -4, - tok_struct = -5, - - // primary - tok_identifier = -6, - tok_number = -7, -}; - -/// The Lexer is an abstract base class providing all the facilities that the -/// Parser expects. It goes through the stream one token at a time and keeps -/// track of the location in the file for debugging purpose. -/// It relies on a subclass to provide a `readNextLine()` method. The subclass -/// can proceed by reading the next line from the standard input or from a -/// memory mapped file. -class Lexer { -public: - /// Create a lexer for the given filename. The filename is kept only for - /// debugging purpose (attaching a location to a Token). - Lexer(std::string filename) - : lastLocation( - {std::make_shared(std::move(filename)), 0, 0}) {} - virtual ~Lexer() = default; - - /// Look at the current token in the stream. - Token getCurToken() { return curTok; } - - /// Move to the next token in the stream and return it. - Token getNextToken() { return curTok = getTok(); } - - /// Move to the next token in the stream, asserting on the current token - /// matching the expectation. - void consume(Token tok) { - assert(tok == curTok && "consume Token mismatch expectation"); - getNextToken(); - } - - /// Return the current identifier (prereq: getCurToken() == tok_identifier) - llvm::StringRef getId() { - assert(curTok == tok_identifier); - return identifierStr; - } - - /// Return the current number (prereq: getCurToken() == tok_number) - double getValue() { - assert(curTok == tok_number); - return numVal; - } - - /// Return the location for the beginning of the current token. - Location getLastLocation() { return lastLocation; } - - // Return the current line in the file. - int getLine() { return curLineNum; } - - // Return the current column in the file. - int getCol() { return curCol; } - -private: - /// Delegate to a derived class fetching the next line. Returns an empty - /// string to signal end of file (EOF). Lines are expected to always finish - /// with "\n" - virtual llvm::StringRef readNextLine() = 0; - - /// Return the next character from the stream. This manages the buffer for the - /// current line and request the next line buffer to the derived class as - /// needed. - int getNextChar() { - // The current line buffer should not be empty unless it is the end of file. - if (curLineBuffer.empty()) - return EOF; - ++curCol; - auto nextchar = curLineBuffer.front(); - curLineBuffer = curLineBuffer.drop_front(); - if (curLineBuffer.empty()) - curLineBuffer = readNextLine(); - if (nextchar == '\n') { - ++curLineNum; - curCol = 0; - } - return nextchar; - } - - /// Return the next token from standard input. - Token getTok() { - // Skip any whitespace. - while (isspace(lastChar)) - lastChar = Token(getNextChar()); - - // Save the current location before reading the token characters. - lastLocation.line = curLineNum; - lastLocation.col = curCol; - - // Identifier: [a-zA-Z][a-zA-Z0-9_]* - if (isalpha(lastChar)) { - identifierStr = (char)lastChar; - while (isalnum((lastChar = Token(getNextChar()))) || lastChar == '_') - identifierStr += (char)lastChar; - - if (identifierStr == "return") - return tok_return; - if (identifierStr == "def") - return tok_def; - if (identifierStr == "struct") - return tok_struct; - if (identifierStr == "var") - return tok_var; - return tok_identifier; - } - - // Number: [0-9] ([0-9.])* - if (isdigit(lastChar)) { - std::string numStr; - do { - numStr += lastChar; - lastChar = Token(getNextChar()); - } while (isdigit(lastChar) || lastChar == '.'); - - numVal = strtod(numStr.c_str(), nullptr); - return tok_number; - } - - if (lastChar == '#') { - // Comment until end of line. - do { - lastChar = Token(getNextChar()); - } while (lastChar != EOF && lastChar != '\n' && lastChar != '\r'); - - if (lastChar != EOF) - return getTok(); - } - - // Check for end of file. Don't eat the EOF. - if (lastChar == EOF) - return tok_eof; - - // Otherwise, just return the character as its ascii value. - Token thisChar = Token(lastChar); - lastChar = Token(getNextChar()); - return thisChar; - } - - /// The last token read from the input. - Token curTok = tok_eof; - - /// Location for `curTok`. - Location lastLocation; - - /// If the current Token is an identifier, this string contains the value. - std::string identifierStr; - - /// If the current Token is a number, this contains the value. - double numVal = 0; - - /// The last value returned by getNextChar(). We need to keep it around as we - /// always need to read ahead one character to decide when to end a token and - /// we can't put it back in the stream after reading from it. - Token lastChar = Token(' '); - - /// Keep track of the current line number in the input stream - int curLineNum = 0; - - /// Keep track of the current column number in the input stream - int curCol = 0; - - /// Buffer supplied by the derived class on calls to `readNextLine()` - llvm::StringRef curLineBuffer = "\n"; -}; - -/// A lexer implementation operating on a buffer in memory. -class LexerBuffer final : public Lexer { -public: - LexerBuffer(const char *begin, const char *end, std::string filename) - : Lexer(std::move(filename)), current(begin), end(end) {} - -private: - /// Provide one line at a time to the Lexer, return an empty string when - /// reaching the end of the buffer. - llvm::StringRef readNextLine() override { - auto *begin = current; - while (current <= end && *current && *current != '\n') - ++current; - if (current <= end && *current) - ++current; - llvm::StringRef result{begin, static_cast(current - begin)}; - return result; - } - const char *current, *end; -}; -} // namespace toy - -#endif // TOY_LEXER_H diff --git a/mlir/example/Ch8/include/toy/MLIRGen.h b/mlir/example/Ch8/include/toy/MLIRGen.h deleted file mode 100644 index fe9dbe5..0000000 --- a/mlir/example/Ch8/include/toy/MLIRGen.h +++ /dev/null @@ -1,35 +0,0 @@ -//===- MLIRGen.h - MLIR Generation from a Toy AST -------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file declares a simple interface to perform IR generation targeting MLIR -// from a Module AST for the Toy language. -// -//===----------------------------------------------------------------------===// - -#ifndef TOY_MLIRGEN_H -#define TOY_MLIRGEN_H - -#include - -namespace mlir { -class MLIRContext; -template -class OwningOpRef; -class ModuleOp; -} // namespace mlir - -namespace toy { -class ModuleAST; - -/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module -/// or nullptr on failure. -mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, - ModuleAST &moduleAST); -} // namespace toy - -#endif // TOY_MLIRGEN_H diff --git a/mlir/example/Ch8/include/toy/Ops.td b/mlir/example/Ch8/include/toy/Ops.td deleted file mode 100644 index 4f76909..0000000 --- a/mlir/example/Ch8/include/toy/Ops.td +++ /dev/null @@ -1,482 +0,0 @@ -//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Defines the operations of the Toy dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef TOY_OPS -#define TOY_OPS - -include "mlir/Interfaces/FunctionInterfaces.td" -include "mlir/IR/SymbolInterfaces.td" -include "mlir/Interfaces/CallInterfaces.td" -include "mlir/Interfaces/CastInterfaces.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "toy/ShapeInferenceInterface.td" - -// Provide a definition of the 'toy' dialect in the ODS framework so that we -// can define our operations. -def Toy_Dialect : Dialect { - let name = "toy"; - let cppNamespace = "::mlir::toy"; - - // We set this bit to generate a declaration of the `materializeConstant` - // method so that we can materialize constants for our toy operations. - let hasConstantMaterializer = 1; - - // We set this bit to generate the declarations for the dialect's type parsing - // and printing hooks. - let useDefaultTypePrinterParser = 1; - -} - -// Base class for toy dialect operations. This operation inherits from the base -// `Op` class in OpBase.td, and provides: -// * The parent dialect of the operation. -// * The mnemonic for the operation, or the name without the dialect prefix. -// * A list of traits for the operation. -class Toy_Op traits = []> : - Op; - -// Provide a definition for the Toy StructType for use in ODS. This allows for -// using StructType in a similar way to Tensor or MemRef. We use `DialectType` -// to demarcate the StructType as belonging to the Toy dialect. -def Toy_StructType : - DialectType($_self)">, - "Toy struct type">; - -// Provide a definition of the types that are used within the Toy dialect. -def Toy_Type : AnyTypeOf<[F64Tensor, Toy_StructType]>; - -//===----------------------------------------------------------------------===// -// Toy Operations -//===----------------------------------------------------------------------===// - -//===----------------------------------------------------------------------===// -// ConstantOp -//===----------------------------------------------------------------------===// - -// We define a toy operation by inheriting from our base 'Toy_Op' class above. -// Here we provide the mnemonic and a list of traits for the operation. The -// constant operation is marked as 'Pure' as it is a pure operation -// and may be removed if dead. -def ConstantOp : Toy_Op<"constant", - [ConstantLike, Pure, - DeclareOpInterfaceMethods]> { - // Provide a summary and description for this operation. This can be used to - // auto-generate documentation of the operations within our dialect. - let summary = "constant"; - let description = [{ - Constant operation turns a literal into an SSA value. The data is attached - to the operation as an attribute. For example: - - ```mlir - %0 = toy.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> - : tensor<2x3xf64> - ``` - }]; - - // The constant operation takes an attribute as the only input. - let arguments = (ins F64ElementsAttr:$value); - - // The constant operation returns a single value of TensorType. - let results = (outs F64Tensor); - - // Indicate that the operation has a custom parser and printer method. - let hasCustomAssemblyFormat = 1; - - // Add custom build methods for the constant operation. These method populates - // the `state` that MLIR uses to create operations, i.e. these are used when - // using `builder.create(...)`. - let builders = [ - // Build a constant with a given constant tensor value. - OpBuilder<(ins "DenseElementsAttr":$value), [{ - build($_builder, $_state, value.getType(), value); - }]>, - - // Build a constant with a given constant floating-point value. - OpBuilder<(ins "double":$value)> - ]; - - // Indicate that additional verification for this operation is necessary. - let hasVerifier = 1; - - // Set the folder bit so that we can implement constant folders. - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// AddOp -//===----------------------------------------------------------------------===// - -def AddOp : Toy_Op<"add", - [Pure, DeclareOpInterfaceMethods]> { - let summary = "element-wise addition operation"; - let description = [{ - The "add" operation performs element-wise addition between two tensors. - The shapes of the tensor operands are expected to match. - }]; - - let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); - let results = (outs F64Tensor); - - // Indicate that the operation has a custom parser and printer method. - let hasCustomAssemblyFormat = 1; - - // Allow building an AddOp with from the two input operands. - let builders = [ - OpBuilder<(ins "Value":$lhs, "Value":$rhs)> - ]; -} - -//===----------------------------------------------------------------------===// -// CastOp -//===----------------------------------------------------------------------===// - -def CastOp : Toy_Op<"cast", [ - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - Pure, - SameOperandsAndResultShape - ]> { - let summary = "shape cast operation"; - let description = [{ - The "cast" operation converts a tensor from one type to an equivalent type - without changing any data elements. The source and destination types must - both be tensor types with the same element type. If both are ranked, then - shape is required to match. The operation is invalid if converting to a - mismatching constant dimension. - }]; - - let arguments = (ins F64Tensor:$input); - let results = (outs F64Tensor:$output); - - let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; -} - -//===----------------------------------------------------------------------===// -// FuncOp -//===----------------------------------------------------------------------===// - -def FuncOp : Toy_Op<"func", [ - FunctionOpInterface, IsolatedFromAbove - ]> { - let summary = "user defined function operation"; - let description = [{ - The "toy.func" operation represents a user defined function. These are - callable SSA-region operations that contain toy computations. - - Example: - - ```mlir - toy.func @main() { - %0 = toy.constant dense<5.500000e+00> : tensor - %1 = toy.reshape(%0 : tensor) to tensor<2x2xf64> - toy.print %1 : tensor<2x2xf64> - toy.return - } - ``` - }]; - - let arguments = (ins - SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type, - OptionalAttr:$arg_attrs, - OptionalAttr:$res_attrs - ); - let regions = (region AnyRegion:$body); - - let builders = [OpBuilder<(ins - "StringRef":$name, "FunctionType":$type, - CArg<"ArrayRef", "{}">:$attrs) - >]; - let extraClassDeclaration = [{ - //===------------------------------------------------------------------===// - // FunctionOpInterface Methods - //===------------------------------------------------------------------===// - - /// Returns the argument types of this function. - ArrayRef getArgumentTypes() { return getFunctionType().getInputs(); } - - /// Returns the result types of this function. - ArrayRef getResultTypes() { return getFunctionType().getResults(); } - - Region *getCallableRegion() { return &getBody(); } - }]; - let hasCustomAssemblyFormat = 1; - let skipDefaultBuilders = 1; -} - -//===----------------------------------------------------------------------===// -// GenericCallOp -//===----------------------------------------------------------------------===// - -def GenericCallOp : Toy_Op<"generic_call", - [DeclareOpInterfaceMethods]> { - let summary = "generic call operation"; - let description = [{ - Generic calls represent calls to a user defined function that needs to - be specialized for the shape of its arguments. The callee name is attached - as a symbol reference via an attribute. The arguments list must match the - arguments expected by the callee. For example: - - ```mlir - %4 = toy.generic_call @my_func(%1, %3) - : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> - ``` - - This is only valid if a function named "my_func" exists and takes two - arguments. - }]; - - // The generic call operation takes a symbol reference attribute as the - // callee, and inputs for the call. - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); - - // The generic call operation returns a single value of TensorType or - // StructType. - let results = (outs Toy_Type); - - // Specialize assembly printing and parsing using a declarative format. - let assemblyFormat = [{ - $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results) - }]; - - // Add custom build methods for the generic call operation. - let builders = [ - OpBuilder<(ins "StringRef":$callee, "ArrayRef":$arguments)> - ]; -} - -//===----------------------------------------------------------------------===// -// MulOp -//===----------------------------------------------------------------------===// - -def MulOp : Toy_Op<"mul", - [Pure, DeclareOpInterfaceMethods]> { - let summary = "element-wise multiplication operation"; - let description = [{ - The "mul" operation performs element-wise multiplication between two - tensors. The shapes of the tensor operands are expected to match. - }]; - - let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); - let results = (outs F64Tensor); - - // Indicate that the operation has a custom parser and printer method. - let hasCustomAssemblyFormat = 1; - - // Allow building a MulOp with from the two input operands. - let builders = [ - OpBuilder<(ins "Value":$lhs, "Value":$rhs)> - ]; -} - -//===----------------------------------------------------------------------===// -// PrintOp -//===----------------------------------------------------------------------===// - -def PrintOp : Toy_Op<"print"> { - let summary = "print operation"; - let description = [{ - The "print" builtin operation prints a given input tensor, and produces - no results. - }]; - - // The print operation takes an input tensor to print. - // We also allow a F64MemRef to enable interop during partial lowering. - let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input); - - let assemblyFormat = "$input attr-dict `:` type($input)"; -} - -//===----------------------------------------------------------------------===// -// ReshapeOp -//===----------------------------------------------------------------------===// - -def ReshapeOp : Toy_Op<"reshape", [Pure]> { - let summary = "tensor reshape operation"; - let description = [{ - Reshape operation is transforming its input tensor into a new tensor with - the same number of elements but different shapes. For example: - - ```mlir - %0 = toy.reshape (%arg1 : tensor<10xf64>) to tensor<5x2xf64> - ``` - }]; - - let arguments = (ins F64Tensor:$input); - - let assemblyFormat = [{ - `(` $input `:` type($input) `)` attr-dict `to` type(results) - }]; - - // Enable registering canonicalization patterns with this operation. - let hasCanonicalizer = 1; - - // We expect that the reshape operation returns a statically shaped tensor. - let results = (outs StaticShapeTensorOf<[F64]>); -} - -//===----------------------------------------------------------------------===// -// ReturnOp -//===----------------------------------------------------------------------===// - -def ReturnOp : Toy_Op<"return", [Pure, HasParent<"FuncOp">, - Terminator]> { - let summary = "return operation"; - let description = [{ - The "return" operation represents a return operation within a function. - The operation takes an optional operand and produces no results. - The operand type must match the signature of the function that contains - the operation. For example: - - ```mlir - toy.func @foo() -> tensor<2xf64> { - ... - toy.return %0 : tensor<2xf64> - } - ``` - }]; - - // The return operation takes an optional input operand to return. This - // value must match the return type of the enclosing function. - let arguments = (ins Variadic:$input); - - // The return operation only emits the input in the format if it is present. - let assemblyFormat = "($input^ `:` type($input))? attr-dict "; - - // Allow building a ReturnOp with no return operand. - let builders = [ - OpBuilder<(ins), [{ build($_builder, $_state, std::nullopt); }]> - ]; - - // Provide extra utility definitions on the c++ operation class definition. - let extraClassDeclaration = [{ - bool hasOperand() { return getNumOperands() != 0; } - }]; - - // Indicate that additional verification for this operation is necessary. - let hasVerifier = 1; -} - -//===----------------------------------------------------------------------===// -// StructAccessOp -//===----------------------------------------------------------------------===// - -def StructAccessOp : Toy_Op<"struct_access", [Pure]> { - let summary = "struct access"; - let description = [{ - Access the Nth element of a value returning a struct type. - }]; - - let arguments = (ins Toy_StructType:$input, I64Attr:$index); - let results = (outs Toy_Type:$output); - - let assemblyFormat = [{ - $input `[` $index `]` attr-dict `:` type($input) `->` type($output) - }]; - - // Allow building a StructAccessOp with just a struct value and an index. - let builders = [ - OpBuilder<(ins "Value":$input, "size_t":$index)> - ]; - - // Indicate that additional verification for this operation is necessary. - let hasVerifier = 1; - - // Set the folder bit so that we can fold constant accesses. - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// StructConstantOp -//===----------------------------------------------------------------------===// - -def StructConstantOp : Toy_Op<"struct_constant", [ConstantLike, Pure]> { - let summary = "struct constant"; - let description = [{ - Constant operation turns a literal struct value into an SSA value. The data - is attached to the operation as an attribute. The struct constant is encoded - as an array of other constant values. For example: - - ```mlir - %0 = toy.struct_constant [ - dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> - ] : !toy.struct> - ``` - }]; - - let arguments = (ins ArrayAttr:$value); - let results = (outs Toy_StructType:$output); - - let assemblyFormat = "$value attr-dict `:` type($output)"; - - // Indicate that additional verification for this operation is necessary. - let hasVerifier = 1; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// TransposeOp -//===----------------------------------------------------------------------===// - -def TransposeOp : Toy_Op<"transpose", - [Pure, DeclareOpInterfaceMethods]> { - let summary = "transpose operation"; - - let arguments = (ins F64Tensor:$input); - let results = (outs F64Tensor); - - let assemblyFormat = [{ - `(` $input `:` type($input) `)` attr-dict `to` type(results) - }]; - - // Enable registering canonicalization patterns with this operation. - let hasCanonicalizer = 1; - - // Allow building a TransposeOp with from the input operand. - let builders = [ - OpBuilder<(ins "Value":$input)> - ]; - - // Indicate that additional verification for this operation is necessary. - let hasVerifier = 1; -} - -//===----------------------------------------------------------------------===// -// MatMul Op -//===----------------------------------------------------------------------===// - -def MatMulOp : Toy_Op<"matmul", - [Pure, DeclareOpInterfaceMethods, MemoryEffectsOpInterface]> { - let summary = "matrix multiplication operation"; - let description = [{ - The "matmul" operation performs Matrix multiplication between two - tensors. The shapes of the tensor operands are expected to match. - }]; - - let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); - let results = (outs Res, - MemAlloc]>:$output); - - let assemblyFormat = [{ - `(` $lhs `:` type($lhs) `,` $rhs `:` type($rhs) `)` attr-dict `to` type(results) - }]; - - // Allow building a MatMulOp with from the two input operands. - let builders = [ - OpBuilder<(ins "Value":$lhs, "Value":$rhs)> - ]; - - let hasVerifier = 1; -} - -#endif // TOY_OPS diff --git a/mlir/example/Ch8/include/toy/Parser.h b/mlir/example/Ch8/include/toy/Parser.h deleted file mode 100644 index 101b03d..0000000 --- a/mlir/example/Ch8/include/toy/Parser.h +++ /dev/null @@ -1,683 +0,0 @@ -//===- Parser.h - Toy Language Parser -------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements the parser for the Toy language. It processes the Token -// provided by the Lexer and returns an AST. -// -//===----------------------------------------------------------------------===// - -#ifndef TOY_PARSER_H -#define TOY_PARSER_H - -#include "toy/AST.h" -#include "toy/Lexer.h" - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/Support/raw_ostream.h" - -#include -#include -#include -#include - -namespace toy { - -/// This is a simple recursive parser for the Toy language. It produces a well -/// formed AST from a stream of Token supplied by the Lexer. No semantic checks -/// or symbol resolution is performed. For example, variables are referenced by -/// string and the code could reference an undeclared variable and the parsing -/// succeeds. -class Parser { -public: - /// Create a Parser for the supplied lexer. - Parser(Lexer &lexer) : lexer(lexer) {} - - /// Parse a full Module. A module is a list of function definitions. - std::unique_ptr parseModule() { - lexer.getNextToken(); // prime the lexer - - // Parse functions and structs one at a time and accumulate in this vector. - std::vector> records; - while (true) { - std::unique_ptr record; - switch (lexer.getCurToken()) { - case tok_eof: - break; - case tok_def: - record = parseDefinition(); - break; - case tok_struct: - record = parseStruct(); - break; - default: - return parseError("'def' or 'struct'", - "when parsing top level module records"); - } - if (!record) - break; - records.push_back(std::move(record)); - } - - // If we didn't reach EOF, there was an error during parsing - if (lexer.getCurToken() != tok_eof) - return parseError("nothing", "at end of module"); - - return std::make_unique(std::move(records)); - } - -private: - Lexer &lexer; - - /// Parse a return statement. - /// return :== return ; | return expr ; - std::unique_ptr parseReturn() { - auto loc = lexer.getLastLocation(); - lexer.consume(tok_return); - - // return takes an optional argument - std::optional> expr; - if (lexer.getCurToken() != ';') { - expr = parseExpression(); - if (!expr) - return nullptr; - } - return std::make_unique(std::move(loc), std::move(expr)); - } - - /// Parse a literal number. - /// numberexpr ::= number - std::unique_ptr parseNumberExpr() { - auto loc = lexer.getLastLocation(); - auto result = - std::make_unique(std::move(loc), lexer.getValue()); - lexer.consume(tok_number); - return std::move(result); - } - - /// Parse a literal array expression. - /// tensorLiteral ::= [ literalList ] | number - /// literalList ::= tensorLiteral | tensorLiteral, literalList - std::unique_ptr parseTensorLiteralExpr() { - auto loc = lexer.getLastLocation(); - lexer.consume(Token('[')); - - // Hold the list of values at this nesting level. - std::vector> values; - // Hold the dimensions for all the nesting inside this level. - std::vector dims; - do { - // We can have either another nested array or a number literal. - if (lexer.getCurToken() == '[') { - values.push_back(parseTensorLiteralExpr()); - if (!values.back()) - return nullptr; // parse error in the nested array. - } else { - if (lexer.getCurToken() != tok_number) - return parseError(" or [", "in literal expression"); - values.push_back(parseNumberExpr()); - } - - // End of this list on ']' - if (lexer.getCurToken() == ']') - break; - - // Elements are separated by a comma. - if (lexer.getCurToken() != ',') - return parseError("] or ,", "in literal expression"); - - lexer.getNextToken(); // eat , - } while (true); - if (values.empty()) - return parseError("", "to fill literal expression"); - lexer.getNextToken(); // eat ] - - /// Fill in the dimensions now. First the current nesting level: - dims.push_back(values.size()); - - /// If there is any nested array, process all of them and ensure that - /// dimensions are uniform. - if (llvm::any_of(values, [](std::unique_ptr &expr) { - return llvm::isa(expr.get()); - })) { - auto *firstLiteral = llvm::dyn_cast(values.front().get()); - if (!firstLiteral) - return parseError("uniform well-nested dimensions", - "inside literal expression"); - - // Append the nested dimensions to the current level - auto firstDims = firstLiteral->getDims(); - dims.insert(dims.end(), firstDims.begin(), firstDims.end()); - - // Sanity check that shape is uniform across all elements of the list. - for (auto &expr : values) { - auto *exprLiteral = llvm::cast(expr.get()); - if (!exprLiteral) - return parseError("uniform well-nested dimensions", - "inside literal expression"); - if (exprLiteral->getDims() != firstDims) - return parseError("uniform well-nested dimensions", - "inside literal expression"); - } - } - return std::make_unique(std::move(loc), std::move(values), - std::move(dims)); - } - - /// Parse a literal struct expression. - /// structLiteral ::= { (structLiteral | tensorLiteral)+ } - std::unique_ptr parseStructLiteralExpr() { - auto loc = lexer.getLastLocation(); - lexer.consume(Token('{')); - - // Hold the list of values. - std::vector> values; - do { - // We can have either another nested array or a number literal. - if (lexer.getCurToken() == '[') { - values.push_back(parseTensorLiteralExpr()); - if (!values.back()) - return nullptr; - } else if (lexer.getCurToken() == tok_number) { - values.push_back(parseNumberExpr()); - if (!values.back()) - return nullptr; - } else { - if (lexer.getCurToken() != '{') - return parseError("{, [, or number", - "in struct literal expression"); - values.push_back(parseStructLiteralExpr()); - } - - // End of this list on '}' - if (lexer.getCurToken() == '}') - break; - - // Elements are separated by a comma. - if (lexer.getCurToken() != ',') - return parseError("} or ,", "in struct literal expression"); - - lexer.getNextToken(); // eat , - } while (true); - if (values.empty()) - return parseError("", - "to fill struct literal expression"); - lexer.getNextToken(); // eat } - - return std::make_unique(std::move(loc), - std::move(values)); - } - - /// parenexpr ::= '(' expression ')' - std::unique_ptr parseParenExpr() { - lexer.getNextToken(); // eat (. - auto v = parseExpression(); - if (!v) - return nullptr; - - if (lexer.getCurToken() != ')') - return parseError(")", "to close expression with parentheses"); - lexer.consume(Token(')')); - return v; - } - - /// Parse a call expression. - std::unique_ptr parseCallExpr(llvm::StringRef name, - const Location &loc) { - lexer.consume(Token('(')); - std::vector> args; - if (lexer.getCurToken() != ')') { - while (true) { - if (auto arg = parseExpression()) - args.push_back(std::move(arg)); - else - return nullptr; - - if (lexer.getCurToken() == ')') - break; - - if (lexer.getCurToken() != ',') - return parseError(", or )", "in argument list"); - lexer.getNextToken(); - } - } - lexer.consume(Token(')')); - - // It can be a builtin call to print - if (name == "print") { - if (args.size() != 1) - return parseError("", "as argument to print()"); - - return std::make_unique(loc, std::move(args[0])); - } - - // Call to a user-defined function - return std::make_unique(loc, std::string(name), - std::move(args)); - } - - /// identifierexpr - /// ::= identifier - /// ::= identifier '(' expression ')' - std::unique_ptr parseIdentifierExpr() { - std::string name(lexer.getId()); - - auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat identifier. - - if (lexer.getCurToken() != '(') // Simple variable ref. - return std::make_unique(std::move(loc), name); - - // This is a function call. - return parseCallExpr(name, loc); - } - - /// primary - /// ::= identifierexpr - /// ::= numberexpr - /// ::= parenexpr - /// ::= tensorliteral - std::unique_ptr parsePrimary() { - switch (lexer.getCurToken()) { - default: - llvm::errs() << "unknown token '" << lexer.getCurToken() - << "' when expecting an expression\n"; - return nullptr; - case tok_identifier: - return parseIdentifierExpr(); - case tok_number: - return parseNumberExpr(); - case '(': - return parseParenExpr(); - case '[': - return parseTensorLiteralExpr(); - case '{': - return parseStructLiteralExpr(); - case ';': - return nullptr; - case '}': - return nullptr; - } - } - - /// Recursively parse the right hand side of a binary expression, the ExprPrec - /// argument indicates the precedence of the current binary operator. - /// - /// binoprhs ::= ('+' primary)* - std::unique_ptr parseBinOpRHS(int exprPrec, - std::unique_ptr lhs) { - // If this is a binop, find its precedence. - while (true) { - int tokPrec = getTokPrecedence(); - - // If this is a binop that binds at least as tightly as the current binop, - // consume it, otherwise we are done. - if (tokPrec < exprPrec) - return lhs; - - // Okay, we know this is a binop. - int binOp = lexer.getCurToken(); - lexer.consume(Token(binOp)); - auto loc = lexer.getLastLocation(); - - // Parse the primary expression after the binary operator. - auto rhs = parsePrimary(); - if (!rhs) - return parseError("expression", "to complete binary operator"); - - // If BinOp binds less tightly with rhs than the operator after rhs, let - // the pending operator take rhs as its lhs. - int nextPrec = getTokPrecedence(); - if (tokPrec < nextPrec) { - rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); - if (!rhs) - return nullptr; - } - - // Merge lhs/RHS. - lhs = std::make_unique(std::move(loc), binOp, - std::move(lhs), std::move(rhs)); - } - } - - /// expression::= primary binop rhs - std::unique_ptr parseExpression() { - auto lhs = parsePrimary(); - if (!lhs) - return nullptr; - - return parseBinOpRHS(0, std::move(lhs)); - } - - /// type ::= < shape_list > - /// shape_list ::= num | num , shape_list - std::unique_ptr parseType() { - if (lexer.getCurToken() != '<') - return parseError("<", "to begin type"); - lexer.getNextToken(); // eat < - - auto type = std::make_unique(); - - while (lexer.getCurToken() == tok_number) { - type->shape.push_back(lexer.getValue()); - lexer.getNextToken(); - if (lexer.getCurToken() == ',') - lexer.getNextToken(); - } - - if (lexer.getCurToken() != '>') - return parseError(">", "to end type"); - lexer.getNextToken(); // eat > - return type; - } - - /// Parse either a variable declaration or a call expression. - std::unique_ptr parseDeclarationOrCallExpr() { - auto loc = lexer.getLastLocation(); - std::string id(lexer.getId()); - lexer.consume(tok_identifier); - - // Check for a call expression. - if (lexer.getCurToken() == '(') - return parseCallExpr(id, loc); - - // Otherwise, this is a variable declaration. - return parseTypedDeclaration(id, /*requiresInitializer=*/true, loc); - } - - /// Parse a typed variable declaration. - std::unique_ptr - parseTypedDeclaration(llvm::StringRef typeName, bool requiresInitializer, - const Location &loc) { - // Parse the variable name. - if (lexer.getCurToken() != tok_identifier) - return parseError("name", "in variable declaration"); - std::string id(lexer.getId()); - lexer.getNextToken(); // eat id - - // Parse the initializer. - std::unique_ptr expr; - if (requiresInitializer) { - if (lexer.getCurToken() != '=') - return parseError("initializer", - "in variable declaration"); - lexer.consume(Token('=')); - expr = parseExpression(); - } - - VarType type; - type.name = std::string(typeName); - return std::make_unique(loc, std::move(id), std::move(type), - std::move(expr)); - } - - /// Parse a variable declaration, for either a tensor value or a struct value, - /// with an optionally required initializer. - /// decl ::= var identifier [ type ] (= expr)? - /// decl ::= identifier identifier (= expr)? - std::unique_ptr parseDeclaration(bool requiresInitializer) { - // Check to see if this is a 'var' declaration. - if (lexer.getCurToken() == tok_var) - return parseVarDeclaration(requiresInitializer); - - // Parse the type name. - if (lexer.getCurToken() != tok_identifier) - return parseError("type name", "in variable declaration"); - auto loc = lexer.getLastLocation(); - std::string typeName(lexer.getId()); - lexer.getNextToken(); // eat id - - // Parse the rest of the declaration. - return parseTypedDeclaration(typeName, requiresInitializer, loc); - } - - /// Parse a variable declaration, it starts with a `var` keyword followed by - /// and identifier and an optional type (shape specification) before the - /// optionally required initializer. - /// decl ::= var identifier [ type ] (= expr)? - std::unique_ptr - parseVarDeclaration(bool requiresInitializer) { - if (lexer.getCurToken() != tok_var) - return parseError("var", "to begin declaration"); - auto loc = lexer.getLastLocation(); - lexer.getNextToken(); // eat var - - if (lexer.getCurToken() != tok_identifier) - return parseError("identified", - "after 'var' declaration"); - std::string id(lexer.getId()); - lexer.getNextToken(); // eat id - - std::unique_ptr type; // Type is optional, it can be inferred - if (lexer.getCurToken() == '<') { - type = parseType(); - if (!type) - return nullptr; - } - if (!type) - type = std::make_unique(); - - std::unique_ptr expr; - if (requiresInitializer) { - lexer.consume(Token('=')); - expr = parseExpression(); - } - return std::make_unique(std::move(loc), std::move(id), - std::move(*type), std::move(expr)); - } - - /// Parse a block: a list of expression separated by semicolons and wrapped in - /// curly braces. - /// - /// block ::= { expression_list } - /// expression_list ::= block_expr ; expression_list - /// block_expr ::= decl | "return" | expr - std::unique_ptr parseBlock() { - if (lexer.getCurToken() != '{') - return parseError("{", "to begin block"); - lexer.consume(Token('{')); - - auto exprList = std::make_unique(); - - // Ignore empty expressions: swallow sequences of semicolons. - while (lexer.getCurToken() == ';') - lexer.consume(Token(';')); - - while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { - if (lexer.getCurToken() == tok_identifier) { - // Variable declaration or call - auto expr = parseDeclarationOrCallExpr(); - if (!expr) - return nullptr; - exprList->push_back(std::move(expr)); - } else if (lexer.getCurToken() == tok_var) { - // Variable declaration - auto varDecl = parseDeclaration(/*requiresInitializer=*/true); - if (!varDecl) - return nullptr; - exprList->push_back(std::move(varDecl)); - } else if (lexer.getCurToken() == tok_return) { - // Return statement - auto ret = parseReturn(); - if (!ret) - return nullptr; - exprList->push_back(std::move(ret)); - } else { - // General expression - auto expr = parseExpression(); - if (!expr) - return nullptr; - exprList->push_back(std::move(expr)); - } - // Ensure that elements are separated by a semicolon. - if (lexer.getCurToken() != ';') - return parseError(";", "after expression"); - - // Ignore empty expressions: swallow sequences of semicolons. - while (lexer.getCurToken() == ';') - lexer.consume(Token(';')); - } - - if (lexer.getCurToken() != '}') - return parseError("}", "to close block"); - - lexer.consume(Token('}')); - return exprList; - } - - /// prototype ::= def id '(' decl_list ')' - /// decl_list ::= identifier | identifier, decl_list - std::unique_ptr parsePrototype() { - auto loc = lexer.getLastLocation(); - - if (lexer.getCurToken() != tok_def) - return parseError("def", "in prototype"); - lexer.consume(tok_def); - - if (lexer.getCurToken() != tok_identifier) - return parseError("function name", "in prototype"); - - std::string fnName(lexer.getId()); - lexer.consume(tok_identifier); - - if (lexer.getCurToken() != '(') - return parseError("(", "in prototype"); - lexer.consume(Token('(')); - - std::vector> args; - if (lexer.getCurToken() != ')') { - do { - VarType type; - std::string name; - - // Parse either the name of the variable, or its type. - std::string nameOrType(lexer.getId()); - auto loc = lexer.getLastLocation(); - lexer.consume(tok_identifier); - - // If the next token is an identifier, we just parsed the type. - if (lexer.getCurToken() == tok_identifier) { - type.name = std::move(nameOrType); - - // Parse the name. - name = std::string(lexer.getId()); - lexer.consume(tok_identifier); - } else { - // Otherwise, we just parsed the name. - name = std::move(nameOrType); - } - - args.push_back( - std::make_unique(std::move(loc), name, type)); - if (lexer.getCurToken() != ',') - break; - lexer.consume(Token(',')); - if (lexer.getCurToken() != tok_identifier) - return parseError( - "identifier", "after ',' in function parameter list"); - } while (true); - } - if (lexer.getCurToken() != ')') - return parseError(")", "to end function prototype"); - - // success. - lexer.consume(Token(')')); - return std::make_unique(std::move(loc), fnName, - std::move(args)); - } - - /// Parse a function definition, we expect a prototype initiated with the - /// `def` keyword, followed by a block containing a list of expressions. - /// - /// definition ::= prototype block - std::unique_ptr parseDefinition() { - auto proto = parsePrototype(); - if (!proto) - return nullptr; - - if (auto block = parseBlock()) - return std::make_unique(std::move(proto), std::move(block)); - return nullptr; - } - - /// Parse a struct definition, we expect a struct initiated with the - /// `struct` keyword, followed by a block containing a list of variable - /// declarations. - /// - /// definition ::= `struct` identifier `{` decl+ `}` - std::unique_ptr parseStruct() { - auto loc = lexer.getLastLocation(); - lexer.consume(tok_struct); - if (lexer.getCurToken() != tok_identifier) - return parseError("name", "in struct definition"); - std::string name(lexer.getId()); - lexer.consume(tok_identifier); - - // Parse: '{' - if (lexer.getCurToken() != '{') - return parseError("{", "in struct definition"); - lexer.consume(Token('{')); - - // Parse: decl+ - std::vector> decls; - do { - auto decl = parseDeclaration(/*requiresInitializer=*/false); - if (!decl) - return nullptr; - decls.push_back(std::move(decl)); - - if (lexer.getCurToken() != ';') - return parseError(";", - "after variable in struct definition"); - lexer.consume(Token(';')); - } while (lexer.getCurToken() != '}'); - - // Parse: '}' - lexer.consume(Token('}')); - return std::make_unique(loc, name, std::move(decls)); - } - - /// Get the precedence of the pending binary operator token. - int getTokPrecedence() { - if (!isascii(lexer.getCurToken())) - return -1; - - // 1 is lowest precedence. - switch (static_cast(lexer.getCurToken())) { - case '-': - return 20; - case '+': - return 20; - case '*': - return 40; - case '.': - return 60; - default: - return -1; - } - } - - /// Helper function to signal errors while parsing, it takes an argument - /// indicating the expected token and another argument giving more context. - /// Location is retrieved from the lexer to enrich the error message. - template - std::unique_ptr parseError(T &&expected, U &&context = "") { - auto curToken = lexer.getCurToken(); - llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " - << lexer.getLastLocation().col << "): expected '" << expected - << "' " << context << " but has Token " << curToken; - if (isprint(curToken)) - llvm::errs() << " '" << (char)curToken << "'"; - llvm::errs() << "\n"; - return nullptr; - } -}; - -} // namespace toy - -#endif // TOY_PARSER_H diff --git a/mlir/example/Ch8/include/toy/Passes.h b/mlir/example/Ch8/include/toy/Passes.h deleted file mode 100644 index 62471dd..0000000 --- a/mlir/example/Ch8/include/toy/Passes.h +++ /dev/null @@ -1,35 +0,0 @@ -//===- Passes.h - Toy Passes Definition -----------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file exposes the entry points to create compiler passes for Toy. -// -//===----------------------------------------------------------------------===// - -#ifndef TOY_PASSES_H -#define TOY_PASSES_H - -#include - -namespace mlir { -class Pass; - -namespace toy { -std::unique_ptr createShapeInferencePass(); - -/// Create a pass for lowering to operations in the `Affine` and `Std` dialects, -/// for a subset of the Toy IR (e.g. matmul). -std::unique_ptr createLowerToAffinePass(); - -/// Create a pass for lowering operations the remaining `Toy` operations, as -/// well as `Affine` and `Std`, to the LLVM dialect for codegen. -std::unique_ptr createLowerToLLVMPass(); - -} // namespace toy -} // namespace mlir - -#endif // TOY_PASSES_H diff --git a/mlir/example/Ch8/include/toy/ShapeInferenceInterface.h b/mlir/example/Ch8/include/toy/ShapeInferenceInterface.h deleted file mode 100644 index cfe5a87..0000000 --- a/mlir/example/Ch8/include/toy/ShapeInferenceInterface.h +++ /dev/null @@ -1,28 +0,0 @@ -//===- ShapeInferenceInterface.h - Interface definitions for ShapeInference -=// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file contains the declarations of the shape inference interfaces defined -// in ShapeInferenceInterface.td. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ -#define MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ - -#include "mlir/IR/OpDefinition.h" - -namespace mlir { -namespace toy { - -/// Include the auto-generated declarations. -#include "toy/ShapeInferenceOpInterfaces.h.inc" - -} // namespace toy -} // namespace mlir - -#endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_ diff --git a/mlir/example/Ch8/include/toy/ShapeInferenceInterface.td b/mlir/example/Ch8/include/toy/ShapeInferenceInterface.td deleted file mode 100644 index 2279015..0000000 --- a/mlir/example/Ch8/include/toy/ShapeInferenceInterface.td +++ /dev/null @@ -1,30 +0,0 @@ -//===- ShapeInferenceInterface.td - Shape Inference Interface -*- tablegen -==// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Defines the operations of the Shape Inference Op Interface. -// -//===----------------------------------------------------------------------===// - -#ifndef SHAPE_INFERENCE_INTERFACE -#define SHAPE_INFERENCE_INTERFACE - -include "mlir/IR/OpBase.td" - -def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> { - let description = [{ - Interface to access a registered method to infer the return types for an - operation that can be used during type inference. - }]; - - let methods = [ - InterfaceMethod<"Infer and set the output shape for the current operation.", - "void", "inferShapes"> - ]; -} - -#endif // SHAPE_INFERENCE_INTERFACE diff --git a/mlir/example/Ch8/matmul.toy b/mlir/example/Ch8/matmul.toy deleted file mode 100644 index b1a0cdb..0000000 --- a/mlir/example/Ch8/matmul.toy +++ /dev/null @@ -1,14 +0,0 @@ -def main() { - # Define a variable `a` with shape <2, 3>, initialized with the literal value. - # The shape is inferred from the supplied literal. - var a = [[1, 2, 3], [4, 5, 6]]; - - # b is identical to a, the literal tensor is implicitly reshaped: defining new - # variables is the way to reshape tensors (element count must match). - var b<2, 3> = [1, 2, 3, 4, 5, 6]; - - # transpose() and print() are the only builtin, the following will transpose - # a and b and perform an element-wise multiplication before printing the result. - # print(a * b + b); - print(matmul(a, transpose(b))); -} diff --git a/mlir/example/Ch8/matmul.toy.mlir b/mlir/example/Ch8/matmul.toy.mlir deleted file mode 100644 index 5a0cd7e..0000000 --- a/mlir/example/Ch8/matmul.toy.mlir +++ /dev/null @@ -1,16 +0,0 @@ -toy.func private @matmul_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> { - %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64> - %1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64> - %2 = toy.matmul(%0 : tensor<*xf64>, %1 : tensor<*xf64>) to tensor<*xf64> - toy.return %2 : tensor<*xf64> -} - -toy.func @main() { - %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> - %1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64> - %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64> - %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<3x2xf64> - %4 = toy.generic_call @matmul_transpose(%1, %3) : (tensor<2x3xf64>, tensor<3x2xf64>) -> tensor<*xf64> - toy.print %4 : tensor<*xf64> - toy.return -} diff --git a/mlir/example/Ch8/mlir/Dialect.cpp b/mlir/example/Ch8/mlir/Dialect.cpp deleted file mode 100644 index 0f5152d..0000000 --- a/mlir/example/Ch8/mlir/Dialect.cpp +++ /dev/null @@ -1,718 +0,0 @@ -//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements the dialect for the Toy IR: custom type parsing and -// operation verification. -// -//===----------------------------------------------------------------------===// - -#include "toy/Dialect.h" - -#include "mlir/Dialect/Arith/Utils/Utils.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/TypeSupport.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Interfaces/CallInterfaces.h" -#include "mlir/Interfaces/FunctionImplementation.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/InliningUtils.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/Hashing.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Casting.h" -#include -#include -#include -#include -#include - -using namespace mlir; -using namespace mlir::toy; - -#include "toy/Dialect.cpp.inc" - -//===----------------------------------------------------------------------===// -// ToyInlinerInterface -//===----------------------------------------------------------------------===// - -/// This class defines the interface for handling inlining with Toy -/// operations. -struct ToyInlinerInterface : public DialectInlinerInterface { - using DialectInlinerInterface::DialectInlinerInterface; - - //===--------------------------------------------------------------------===// - // Analysis Hooks - //===--------------------------------------------------------------------===// - - /// All call operations within toy can be inlined. - bool isLegalToInline(Operation *call, Operation *callable, - bool wouldBeCloned) const final { - return true; - } - - /// All operations within toy can be inlined. - bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { - return true; - } - - // All functions within toy can be inlined. - bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { - return true; - } - - //===--------------------------------------------------------------------===// - // Transformation Hooks - //===--------------------------------------------------------------------===// - - /// Handle the given inlined terminator(toy.return) by replacing it with a new - /// operation as necessary. - void handleTerminator(Operation *op, ValueRange valuesToRepl) const final { - // Only "toy.return" needs to be handled here. - auto returnOp = cast(op); - - // Replace the values directly with the return operands. - assert(returnOp.getNumOperands() == valuesToRepl.size()); - for (const auto &it : llvm::enumerate(returnOp.getOperands())) - valuesToRepl[it.index()].replaceAllUsesWith(it.value()); - } - - /// Attempts to materialize a conversion for a type mismatch between a call - /// from this dialect, and a callable region. This method should generate an - /// operation that takes 'input' as the only operand, and produces a single - /// result of 'resultType'. If a conversion can not be generated, nullptr - /// should be returned. - Operation *materializeCallConversion(OpBuilder &builder, Value input, - Type resultType, - Location conversionLoc) const final { - return builder.create(conversionLoc, resultType, input); - } -}; - -//===----------------------------------------------------------------------===// -// Toy Operations -//===----------------------------------------------------------------------===// - -/// A generalized parser for binary operations. This parses the different forms -/// of 'printBinaryOp' below. -static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser, - mlir::OperationState &result) { - SmallVector operands; - SMLoc operandsLoc = parser.getCurrentLocation(); - Type type; - if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type)) - return mlir::failure(); - - // If the type is a function type, it contains the input and result types of - // this operation. - if (FunctionType funcType = llvm::dyn_cast(type)) { - if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc, - result.operands)) - return mlir::failure(); - result.addTypes(funcType.getResults()); - return mlir::success(); - } - - // Otherwise, the parsed type is the type of both operands and results. - if (parser.resolveOperands(operands, type, result.operands)) - return mlir::failure(); - result.addTypes(type); - return mlir::success(); -} - -/// A generalized printer for binary operations. It prints in two different -/// forms depending on if all of the types match. -static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) { - printer << " " << op->getOperands(); - printer.printOptionalAttrDict(op->getAttrs()); - printer << " : "; - - // If all of the types are the same, print the type directly. - Type resultType = *op->result_type_begin(); - if (llvm::all_of(op->getOperandTypes(), - [=](Type type) { return type == resultType; })) { - printer << resultType; - return; - } - - // Otherwise, print a functional type. - printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes()); -} - -//===----------------------------------------------------------------------===// -// ConstantOp -//===----------------------------------------------------------------------===// - -/// Build a constant operation. -/// The builder is passed as an argument, so is the state that this method is -/// expected to fill in order to build the operation. -void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - double value) { - auto dataType = RankedTensorType::get({}, builder.getF64Type()); - auto dataAttribute = DenseElementsAttr::get(dataType, value); - ConstantOp::build(builder, state, dataType, dataAttribute); -} - -/// The 'OpAsmParser' class provides a collection of methods for parsing -/// various punctuation, as well as attributes, operands, types, etc. Each of -/// these methods returns a `ParseResult`. This class is a wrapper around -/// `LogicalResult` that can be converted to a boolean `true` value on failure, -/// or `false` on success. This allows for easily chaining together a set of -/// parser rules. These rules are used to populate an `mlir::OperationState` -/// similarly to the `build` methods described above. -mlir::ParseResult ConstantOp::parse(mlir::OpAsmParser &parser, - mlir::OperationState &result) { - mlir::DenseElementsAttr value; - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseAttribute(value, "value", result.attributes)) - return failure(); - - result.addTypes(value.getType()); - return success(); -} - -/// The 'OpAsmPrinter' class is a stream that allows for formatting -/// strings, attributes, operands, types, etc. -void ConstantOp::print(mlir::OpAsmPrinter &printer) { - printer << " "; - printer.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); - printer << getValue(); -} - -/// Verify that the given attribute value is valid for the given type. -static llvm::LogicalResult verifyConstantForType(mlir::Type type, - mlir::Attribute opaqueValue, - mlir::Operation *op) { - if (llvm::isa(type)) { - // Check that the value is an elements attribute. - auto attrValue = llvm::dyn_cast(opaqueValue); - if (!attrValue) - return op->emitError("constant of TensorType must be initialized by " - "a DenseFPElementsAttr, got ") - << opaqueValue; - - // If the return type of the constant is not an unranked tensor, the shape - // must match the shape of the attribute holding the data. - auto resultType = llvm::dyn_cast(type); - if (!resultType) - return success(); - - // Check that the rank of the attribute type matches the rank of the - // constant result type. - auto attrType = llvm::cast(attrValue.getType()); - if (attrType.getRank() != resultType.getRank()) { - return op->emitOpError("return type must match the one of the attached " - "value attribute: ") - << attrType.getRank() << " != " << resultType.getRank(); - } - - // Check that each of the dimensions match between the two types. - for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) { - if (attrType.getShape()[dim] != resultType.getShape()[dim]) { - return op->emitOpError( - "return type shape mismatches its attribute at dimension ") - << dim << ": " << attrType.getShape()[dim] - << " != " << resultType.getShape()[dim]; - } - } - return mlir::success(); - } - auto resultType = llvm::cast(type); - llvm::ArrayRef resultElementTypes = resultType.getElementTypes(); - - // Verify that the initializer is an Array. - auto attrValue = llvm::dyn_cast(opaqueValue); - if (!attrValue || attrValue.getValue().size() != resultElementTypes.size()) - return op->emitError("constant of StructType must be initialized by an " - "ArrayAttr with the same number of elements, got ") - << opaqueValue; - - // Check that each of the elements are valid. - llvm::ArrayRef attrElementValues = attrValue.getValue(); - for (const auto it : llvm::zip(resultElementTypes, attrElementValues)) - if (failed(verifyConstantForType(std::get<0>(it), std::get<1>(it), op))) - return mlir::failure(); - return mlir::success(); -} - -/// Verifier for the constant operation. This corresponds to the `::verify(...)` -/// in the op definition. -llvm::LogicalResult ConstantOp::verify() { - return verifyConstantForType(getResult().getType(), getValue(), *this); -} - -llvm::LogicalResult StructConstantOp::verify() { - return verifyConstantForType(getResult().getType(), getValue(), *this); -} - -/// Infer the output shape of the ConstantOp, this is required by the shape -/// inference interface. -void ConstantOp::inferShapes() { - getResult().setType(cast(getValue().getType())); -} - -//===----------------------------------------------------------------------===// -// AddOp -//===----------------------------------------------------------------------===// - -void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::Value lhs, mlir::Value rhs) { - state.addTypes(UnrankedTensorType::get(builder.getF64Type())); - state.addOperands({lhs, rhs}); -} - -mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser, - mlir::OperationState &result) { - return parseBinaryOp(parser, result); -} - -void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } - -/// Infer the output shape of the AddOp, this is required by the shape inference -/// interface. -void AddOp::inferShapes() { getResult().setType(getLhs().getType()); } - -//===----------------------------------------------------------------------===// -// CastOp -//===----------------------------------------------------------------------===// - -/// Infer the output shape of the CastOp, this is required by the shape -/// inference interface. -void CastOp::inferShapes() { getResult().setType(getInput().getType()); } - -/// Returns true if the given set of input and result types are compatible with -/// this cast operation. This is required by the `CastOpInterface` to verify -/// this operation and provide other additional utilities. -bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { - if (inputs.size() != 1 || outputs.size() != 1) - return false; - // The inputs must be Tensors with the same element type. - TensorType input = llvm::dyn_cast(inputs.front()); - TensorType output = llvm::dyn_cast(outputs.front()); - if (!input || !output || input.getElementType() != output.getElementType()) - return false; - // The shape is required to match if both types are ranked. - return !input.hasRank() || !output.hasRank() || input == output; -} - -//===----------------------------------------------------------------------===// -// FuncOp -//===----------------------------------------------------------------------===// - -void FuncOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - llvm::StringRef name, mlir::FunctionType type, - llvm::ArrayRef attrs) { - // FunctionOpInterface provides a convenient `build` method that will populate - // the state of our FuncOp, and create an entry block. - buildWithEntryBlock(builder, state, name, type, attrs, type.getInputs()); -} - -mlir::ParseResult FuncOp::parse(mlir::OpAsmParser &parser, - mlir::OperationState &result) { - // Dispatch to the FunctionOpInterface provided utility method that parses the - // function operation. - auto buildFuncType = - [](mlir::Builder &builder, llvm::ArrayRef argTypes, - llvm::ArrayRef results, - mlir::function_interface_impl::VariadicFlag, - std::string &) { return builder.getFunctionType(argTypes, results); }; - - return mlir::function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, - getFunctionTypeAttrName(result.name), buildFuncType, - getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); -} - -void FuncOp::print(mlir::OpAsmPrinter &p) { - // Dispatch to the FunctionOpInterface provided utility method that prints the - // function operation. - mlir::function_interface_impl::printFunctionOp( - p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), - getArgAttrsAttrName(), getResAttrsAttrName()); -} - -//===----------------------------------------------------------------------===// -// GenericCallOp -//===----------------------------------------------------------------------===// - -void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - StringRef callee, ArrayRef arguments) { - // Generic call always returns an unranked Tensor initially. - state.addTypes(UnrankedTensorType::get(builder.getF64Type())); - state.addOperands(arguments); - state.addAttribute("callee", - mlir::SymbolRefAttr::get(builder.getContext(), callee)); -} - -/// Return the callee of the generic call operation, this is required by the -/// call interface. -CallInterfaceCallable GenericCallOp::getCallableForCallee() { - return (*this)->getAttrOfType("callee"); -} - -/// Set the callee for the generic call operation, this is required by the call -/// interface. -void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) { - (*this)->setAttr("callee", callee.get()); -} - -/// Get the argument operands to the called function, this is required by the -/// call interface. -Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); } - -/// Get the argument operands to the called function as a mutable range, this is -/// required by the call interface. -MutableOperandRange GenericCallOp::getArgOperandsMutable() { - return getInputsMutable(); -} - -//===----------------------------------------------------------------------===// -// MulOp -//===----------------------------------------------------------------------===// - -void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::Value lhs, mlir::Value rhs) { - state.addTypes(UnrankedTensorType::get(builder.getF64Type())); - state.addOperands({lhs, rhs}); -} - -mlir::ParseResult MulOp::parse(mlir::OpAsmParser &parser, - mlir::OperationState &result) { - return parseBinaryOp(parser, result); -} - -void MulOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); } - -/// Infer the output shape of the MulOp, this is required by the shape inference -/// interface. -void MulOp::inferShapes() { getResult().setType(getLhs().getType()); } - -//===----------------------------------------------------------------------===// -// ReturnOp -//===----------------------------------------------------------------------===// - -llvm::LogicalResult ReturnOp::verify() { - // We know that the parent operation is a function, because of the 'HasParent' - // trait attached to the operation definition. - auto function = cast((*this)->getParentOp()); - - /// ReturnOps can only have a single optional operand. - if (getNumOperands() > 1) - return emitOpError() << "expects at most 1 return operand"; - - // The operand number and types must match the function signature. - const auto &results = function.getFunctionType().getResults(); - if (getNumOperands() != results.size()) - return emitOpError() << "does not return the same number of values (" - << getNumOperands() << ") as the enclosing function (" - << results.size() << ")"; - - // If the operation does not have an input, we are done. - if (!hasOperand()) - return mlir::success(); - - auto inputType = *operand_type_begin(); - auto resultType = results.front(); - - // Check that the result type of the function matches the operand type. - if (inputType == resultType || - llvm::isa(inputType) || - llvm::isa(resultType)) - return mlir::success(); - - return emitError() << "type of return operand (" << inputType - << ") doesn't match function result type (" << resultType - << ")"; -} - -//===----------------------------------------------------------------------===// -// StructAccessOp -//===----------------------------------------------------------------------===// - -void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state, - mlir::Value input, size_t index) { - // Extract the result type from the input type. - StructType structTy = llvm::cast(input.getType()); - assert(index < structTy.getNumElementTypes()); - mlir::Type resultType = structTy.getElementTypes()[index]; - - // Call into the auto-generated build method. - build(b, state, resultType, input, b.getI64IntegerAttr(index)); -} - -llvm::LogicalResult StructAccessOp::verify() { - StructType structTy = llvm::cast(getInput().getType()); - size_t indexValue = getIndex(); - if (indexValue >= structTy.getNumElementTypes()) - return emitOpError() - << "index should be within the range of the input struct type"; - mlir::Type resultType = getResult().getType(); - if (resultType != structTy.getElementTypes()[indexValue]) - return emitOpError() << "must have the same result type as the struct " - "element referred to by the index"; - return mlir::success(); -} - -//===----------------------------------------------------------------------===// -// TransposeOp -//===----------------------------------------------------------------------===// - -void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::Value value) { - state.addTypes(UnrankedTensorType::get(builder.getF64Type())); - state.addOperands(value); -} - -void TransposeOp::inferShapes() { - auto arrayTy = llvm::cast(getOperand().getType()); - SmallVector dims(llvm::reverse(arrayTy.getShape())); - getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); -} - -llvm::LogicalResult TransposeOp::verify() { - auto inputType = llvm::dyn_cast(getOperand().getType()); - auto resultType = llvm::dyn_cast(getType()); - if (!inputType || !resultType) - return mlir::success(); - - auto inputShape = inputType.getShape(); - if (!std::equal(inputShape.begin(), inputShape.end(), - resultType.getShape().rbegin())) { - return emitError() - << "expected result shape to be a transpose of the input"; - } - return mlir::success(); -} - -//===----------------------------------------------------------------------===// -// MatMulOp -//===----------------------------------------------------------------------===// - -void MatMulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::Value lhs, mlir::Value rhs) { - state.addTypes(UnrankedTensorType::get(builder.getF64Type())); - state.addOperands({lhs, rhs}); -} - -/// Infer the output shape of the MatMulOp, this is required by the shape -/// inference interface. -void MatMulOp::inferShapes() { - RankedTensorType lhsType = - llvm::dyn_cast(getLhs().getType()); - RankedTensorType rhsType = - llvm::dyn_cast(getRhs().getType()); - auto lhsShape = lhsType.getShape(); - auto rhsShape = rhsType.getShape(); - RankedTensorType res_type = RankedTensorType::get({lhsShape[0], rhsShape[1]}, - lhsType.getElementType()); - getResult().setType(res_type); -} - -llvm::LogicalResult MatMulOp::verify() { - auto lhsType = llvm::dyn_cast(getLhs().getType()); - auto rhsType = llvm::dyn_cast(getRhs().getType()); - auto resultType = llvm::dyn_cast(getType()); - - if (!lhsType || !rhsType || !resultType) - return mlir::success(); - - auto lhsShape = lhsType.getShape(); - auto rhsShape = rhsType.getShape(); - - if (lhsShape.size() != 2 || rhsShape.size() != 2) { - return emitOpError() << "expected 2D matrix"; - } - - if (lhsShape[1] != rhsShape[0]) { - return emitOpError() << "expected dimension to match" - << "the shape of lhs is [" << lhsShape[0] << ", " - << lhsShape[1] << "] " - << "the shape of rhs is [" << rhsShape[0] << ", " - << rhsShape[1] << "] " - << "but the dimension " << lhsShape[1] - << "!=" << rhsShape[0] << '\n'; - } - - return mlir::success(); -} - -//===----------------------------------------------------------------------===// -// Toy Types -//===----------------------------------------------------------------------===// - -namespace mlir { -namespace toy { -namespace detail { -/// This class represents the internal storage of the Toy `StructType`. -struct StructTypeStorage : public mlir::TypeStorage { - /// The `KeyTy` is a required type that provides an interface for the storage - /// instance. This type will be used when uniquing an instance of the type - /// storage. For our struct type, we will unique each instance structurally on - /// the elements that it contains. - using KeyTy = llvm::ArrayRef; - - /// A constructor for the type storage instance. - StructTypeStorage(llvm::ArrayRef elementTypes) - : elementTypes(elementTypes) {} - - /// Define the comparison function for the key type with the current storage - /// instance. This is used when constructing a new instance to ensure that we - /// haven't already uniqued an instance of the given key. - bool operator==(const KeyTy &key) const { return key == elementTypes; } - - /// Define a hash function for the key type. This is used when uniquing - /// instances of the storage, see the `StructType::get` method. - /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type - /// have hash functions available, so we could just omit this entirely. - static llvm::hash_code hashKey(const KeyTy &key) { - return llvm::hash_value(key); - } - - /// Define a construction function for the key type from a set of parameters. - /// These parameters will be provided when constructing the storage instance - /// itself. - /// Note: This method isn't necessary because KeyTy can be directly - /// constructed with the given parameters. - static KeyTy getKey(llvm::ArrayRef elementTypes) { - return KeyTy(elementTypes); - } - - /// Define a construction method for creating a new instance of this storage. - /// This method takes an instance of a storage allocator, and an instance of a - /// `KeyTy`. The given allocator must be used for *all* necessary dynamic - /// allocations used to create the type storage and its internal. - static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator, - const KeyTy &key) { - // Copy the elements from the provided `KeyTy` into the allocator. - llvm::ArrayRef elementTypes = allocator.copyInto(key); - - // Allocate the storage instance and construct it. - return new (allocator.allocate()) - StructTypeStorage(elementTypes); - } - - /// The following field contains the element types of the struct. - llvm::ArrayRef elementTypes; -}; -} // namespace detail -} // namespace toy -} // namespace mlir - -/// Create an instance of a `StructType` with the given element types. There -/// *must* be at least one element type. -StructType StructType::get(llvm::ArrayRef elementTypes) { - assert(!elementTypes.empty() && "expected at least 1 element type"); - - // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance - // of this type. The first parameter is the context to unique in. The - // parameters after the context are forwarded to the storage instance. - mlir::MLIRContext *ctx = elementTypes.front().getContext(); - return Base::get(ctx, elementTypes); -} - -/// Returns the element types of this struct type. -llvm::ArrayRef StructType::getElementTypes() { - // 'getImpl' returns a pointer to the internal storage instance. - return getImpl()->elementTypes; -} - -/// Parse an instance of a type registered to the toy dialect. -mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const { - // Parse a struct type in the following form: - // struct-type ::= `struct` `<` type (`,` type)* `>` - - // NOTE: All MLIR parser function return a ParseResult. This is a - // specialization of LogicalResult that auto-converts to a `true` boolean - // value on failure to allow for chaining, but may be used with explicit - // `mlir::failed/mlir::succeeded` as desired. - - // Parse: `struct` `<` - if (parser.parseKeyword("struct") || parser.parseLess()) - return Type(); - - // Parse the element types of the struct. - SmallVector elementTypes; - do { - // Parse the current element type. - SMLoc typeLoc = parser.getCurrentLocation(); - mlir::Type elementType; - if (parser.parseType(elementType)) - return nullptr; - - // Check that the type is either a TensorType or another StructType. - if (!llvm::isa(elementType)) { - parser.emitError(typeLoc, "element type for a struct must either " - "be a TensorType or a StructType, got: ") - << elementType; - return Type(); - } - elementTypes.push_back(elementType); - - // Parse the optional: `,` - } while (succeeded(parser.parseOptionalComma())); - - // Parse: `>` - if (parser.parseGreater()) - return Type(); - return StructType::get(elementTypes); -} - -/// Print an instance of a type registered to the toy dialect. -void ToyDialect::printType(mlir::Type type, - mlir::DialectAsmPrinter &printer) const { - // Currently the only toy type is a struct type. - StructType structType = llvm::cast(type); - - // Print the struct type according to the parser format. - printer << "struct<"; - llvm::interleaveComma(structType.getElementTypes(), printer); - printer << '>'; -} - -//===----------------------------------------------------------------------===// -// TableGen'd op method definitions -//===----------------------------------------------------------------------===// - -#define GET_OP_CLASSES -#include "toy/Ops.cpp.inc" - -//===----------------------------------------------------------------------===// -// ToyDialect -//===----------------------------------------------------------------------===// - -/// Dialect initialization, the instance will be owned by the context. This is -/// the point of registration of types and operations for the dialect. -void ToyDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "toy/Ops.cpp.inc" - >(); - addInterfaces(); - addTypes(); -} - -mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder, - mlir::Attribute value, - mlir::Type type, - mlir::Location loc) { - if (llvm::isa(type)) - return builder.create(loc, type, - llvm::cast(value)); - return builder.create(loc, type, - llvm::cast(value)); -} diff --git a/mlir/example/Ch8/mlir/LowerToAffineLoops.cpp b/mlir/example/Ch8/mlir/LowerToAffineLoops.cpp deleted file mode 100644 index c21ded3..0000000 --- a/mlir/example/Ch8/mlir/LowerToAffineLoops.cpp +++ /dev/null @@ -1,470 +0,0 @@ -//====- LowerToAffineLoops.cpp - Partial lowering from Toy to Affine+Std --===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a partial lowering of Toy operations to a combination of -// affine loops, memref operations and standard operations. This lowering -// expects that all calls have been inlined, and all shapes have been resolved. -// -//===----------------------------------------------------------------------===// - -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinDialect.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/DialectRegistry.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/TypeID.h" -#include "toy/Dialect.h" -#include "toy/Passes.h" - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/APFloat.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/Sequence.h" -#include "llvm/Support/Casting.h" -#include -#include -#include -#include -#include - -using namespace mlir; - -//===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns -//===----------------------------------------------------------------------===// - -/// Convert the given RankedTensorType into the corresponding MemRefType. -static MemRefType convertTensorToMemRef(RankedTensorType type) { - return MemRefType::get(type.getShape(), type.getElementType()); -} - -/// Insert an allocation and deallocation for the given MemRefType. -static Value insertAllocAndDealloc(MemRefType type, Location loc, - PatternRewriter &rewriter) { - auto alloc = rewriter.create(loc, type); - - // Make sure to allocate at the beginning of the block. - auto *parentBlock = alloc->getBlock(); - alloc->moveBefore(&parentBlock->front()); - - // Make sure to deallocate this alloc at the end of the block. This is fine - // as toy functions have no control flow. - auto dealloc = rewriter.create(loc, alloc); - dealloc->moveBefore(&parentBlock->back()); - return alloc; -} - -/// This defines the function type used to process an iteration of a lowered -/// loop. It takes as input an OpBuilder, an range of memRefOperands -/// corresponding to the operands of the input operation, and the range of loop -/// induction variables for the iteration. It returns a value to store at the -/// current index of the iteration. -using LoopIterationFn = function_ref; - -static void lowerOpToLoops(Operation *op, ValueRange operands, - PatternRewriter &rewriter, - LoopIterationFn processIteration) { - auto tensorType = llvm::cast((*op->result_type_begin())); - auto loc = op->getLoc(); - - // Insert an allocation and deallocation for the result of this operation. - auto memRefType = convertTensorToMemRef(tensorType); - auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); - - // Create a nest of affine loops, with one loop per dimension of the shape. - // The buildAffineLoopNest function takes a callback that is used to construct - // the body of the innermost loop given a builder, a location and a range of - // loop induction variables. - SmallVector lowerBounds(tensorType.getRank(), /*Value=*/0); - SmallVector steps(tensorType.getRank(), /*Value=*/1); - affine::buildAffineLoopNest( - rewriter, loc, lowerBounds, tensorType.getShape(), steps, - [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { - // Call the processing function with the rewriter, the memref operands, - // and the loop induction variables. This function will return the value - // to store at the current index. - Value valueToStore = processIteration(nestedBuilder, operands, ivs); - nestedBuilder.create(loc, valueToStore, alloc, - ivs); - }); - - // Replace this operation with the generated alloc. - rewriter.replaceOp(op, alloc); -} - -namespace { -//===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Binary operations -//===----------------------------------------------------------------------===// - -template -struct BinaryOpLowering : public ConversionPattern { - BinaryOpLowering(MLIRContext *ctx) - : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} - - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // BinaryOp. This allows for using the nice named accessors - // that are generated by the ODS. - typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); - - // Generate loads for the element of 'lhs' and 'rhs' at the - // inner loop. - auto loadedLhs = builder.create( - loc, binaryAdaptor.getLhs(), loopIvs); - auto loadedRhs = builder.create( - loc, binaryAdaptor.getRhs(), loopIvs); - - // Create the binary operation performed on the loaded - // values. - return builder.create(loc, loadedLhs, - loadedRhs); - }); - return success(); - } -}; -using AddOpLowering = BinaryOpLowering; -using MulOpLowering = BinaryOpLowering; - -//===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Constant operations -//===----------------------------------------------------------------------===// - -struct ConstantOpLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(toy::ConstantOp op, - PatternRewriter &rewriter) const final { - DenseElementsAttr constantValue = op.getValue(); - Location loc = op.getLoc(); - - // When lowering the constant operation, we allocate and assign the constant - // values to a corresponding memref allocation. - auto tensorType = llvm::cast(op.getType()); - auto memRefType = convertTensorToMemRef(tensorType); - auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); - - // We will be generating constant indices up-to the largest dimension. - // Create these constants up-front to avoid large amounts of redundant - // operations. - auto valueShape = memRefType.getShape(); - SmallVector constantIndices; - - if (!valueShape.empty()) { - for (auto i : llvm::seq(0, *llvm::max_element(valueShape))) - constantIndices.push_back( - rewriter.create(loc, i)); - } else { - // This is the case of a tensor of rank 0. - constantIndices.push_back( - rewriter.create(loc, 0)); - } - - // The constant operation represents a multi-dimensional constant, so we - // will need to generate a store for each of the elements. The following - // functor recursively walks the dimensions of the constant shape, - // generating a store when the recursion hits the base case. - SmallVector indices; - auto valueIt = constantValue.value_begin(); - std::function storeElements = [&](uint64_t dimension) { - // The last dimension is the base case of the recursion, at this point - // we store the element at the given index. - if (dimension == valueShape.size()) { - rewriter.create( - loc, rewriter.create(loc, *valueIt++), alloc, - llvm::ArrayRef(indices)); - return; - } - - // Otherwise, iterate over the current dimension and add the indices to - // the list. - for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) { - indices.push_back(constantIndices[i]); - storeElements(dimension + 1); - indices.pop_back(); - } - }; - - // Start the element storing recursion from the first dimension. - storeElements(/*dimension=*/0); - - // Replace this operation with the generated alloc. - rewriter.replaceOp(op, alloc); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Func operations -//===----------------------------------------------------------------------===// - -struct FuncOpLowering : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(toy::FuncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - // We only lower the main function as we expect that all other functions - // have been inlined. - if (op.getName() != "main") - return failure(); - - // Verify that the given main has no inputs and results. - if (op.getNumArguments() || op.getFunctionType().getNumResults()) { - return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { - diag << "expected 'main' to have 0 inputs and 0 results"; - }); - } - - // Create a new non-toy function, with the same region. - auto func = rewriter.create(op.getLoc(), op.getName(), - op.getFunctionType()); - rewriter.inlineRegionBefore(op.getRegion(), func.getBody(), func.end()); - rewriter.eraseOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Print operations -//===----------------------------------------------------------------------===// - -struct PrintOpLowering : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - // We don't lower "toy.print" in this pass, but we need to update its - // operands. - rewriter.modifyOpInPlace(op, - [&] { op->setOperands(adaptor.getOperands()); }); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Return operations -//===----------------------------------------------------------------------===// - -struct ReturnOpLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(toy::ReturnOp op, - PatternRewriter &rewriter) const final { - // During this lowering, we expect that all function calls have been - // inlined. - if (op.hasOperand()) - return failure(); - - // We lower "toy.return" directly to "func.return". - rewriter.replaceOpWithNewOp(op); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: Transpose operations -//===----------------------------------------------------------------------===// - -struct TransposeOpLowering : public ConversionPattern { - TransposeOpLowering(MLIRContext *ctx) - : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} - - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - auto loc = op->getLoc(); - lowerOpToLoops(op, operands, rewriter, - [loc](OpBuilder &builder, ValueRange memRefOperands, - ValueRange loopIvs) { - // Generate an adaptor for the remapped operands of the - // TransposeOp. This allows for using the nice named - // accessors that are generated by the ODS. - toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); - Value input = transposeAdaptor.getInput(); - - // Transpose the elements by generating a load from the - // reverse indices. - SmallVector reverseIvs(llvm::reverse(loopIvs)); - return builder.create(loc, input, - reverseIvs); - }); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// ToyToAffine RewritePatterns: MatMul operations -//===----------------------------------------------------------------------===// - -struct MatMulOpLowering : public ConversionPattern { - MatMulOpLowering(MLIRContext *ctx) - : ConversionPattern(toy::MatMulOp::getOperationName(), 1, ctx) {} - - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - auto loc = op->getLoc(); - - RankedTensorType lhsType = - llvm::dyn_cast(op->getOperand(0).getType()); - RankedTensorType rhsType = - llvm::dyn_cast(op->getOperand(1).getType()); - auto lhsShape = lhsType.getShape(); - auto rhsShape = rhsType.getShape(); - - auto tensorType = - llvm::dyn_cast((*op->result_type_begin())); - - auto elemType = llvm::dyn_cast(tensorType.getElementType()); - - // Insert an allocation and deallocation for the result of this operation. - auto memRefType = convertTensorToMemRef(tensorType); - auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter); - - SmallVector lowerBounds(tensorType.getRank() + 1, /*Value=*/0); - SmallVector steps(tensorType.getRank() + 1, /*Value=*/1); - SmallVector upperBounds{lhsShape[0], rhsShape[0], rhsShape[1]}; - - // add initialization of result tensor. - // Create a nest of affine loops to initialize the result tensor to 0. - affine::buildAffineLoopNest( - rewriter, loc, {0, 0}, tensorType.getShape(), {1, 1}, - [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { - // Create a constant float value of 0.0. - auto valueToStore = nestedBuilder.create( - loc, llvm::APFloat(0.0), elemType); - // Store the constant value into the allocated memory. - nestedBuilder.create(loc, valueToStore, alloc, - ivs); - }); - - // Create a nest of affine loops for matrix multiplication. - affine::buildAffineLoopNest( - rewriter, loc, lowerBounds, upperBounds, steps, - [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { - // Extract loop induction variables. - Value m = ivs[0]; - Value k = ivs[1]; - Value n = ivs[2]; - - // Create an adaptor for the remapped operands of the MatMulOp. - toy::MatMulOpAdaptor matmulAdaptor(operands); - - // Load elements from the left-hand side and right-hand side matrices. - auto loadedLhs = nestedBuilder.create( - loc, matmulAdaptor.getLhs(), ValueRange{m, k}); - auto loadedRhs = nestedBuilder.create( - loc, matmulAdaptor.getRhs(), ValueRange{k, n}); - // Load elements from the result tensor from initial process above. - auto loadedRes = nestedBuilder.create( - loc, alloc, ValueRange{m, n}); - - // Perform the multiplication and addition operations. - auto mulop = - nestedBuilder.create(loc, loadedLhs, loadedRhs); - auto valueToStore = - nestedBuilder.create(loc, loadedRes, mulop); - - // Store the result back into the allocated memory. - nestedBuilder.create(loc, valueToStore, alloc, - ValueRange{m, n}); - }); - - // Replace this operation with the generated alloc. - rewriter.replaceOp(op, alloc); - - return success(); - } -}; - -} // namespace - -//===----------------------------------------------------------------------===// -// ToyToAffineLoweringPass -//===----------------------------------------------------------------------===// - -/// This is a partial lowering to affine loops of the toy operations that are -/// computationally intensive (like matmul for example...) while keeping the -/// rest of the code in the Toy dialect. -namespace { -struct ToyToAffineLoweringPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ToyToAffineLoweringPass) - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - void runOnOperation() final; -}; -} // namespace - -void ToyToAffineLoweringPass::runOnOperation() { - // The first thing to define is the conversion target. This will define the - // final target for this lowering. - ConversionTarget target(getContext()); - - // We define the specific operations, or dialects, that are legal targets for - // this lowering. In our case, we are lowering to a combination of the - // `Affine`, `Arith`, `Func`, and `MemRef` dialects. - target.addLegalDialect(); - - // We also define the Toy dialect as Illegal so that the conversion will fail - // if any of these operations are *not* converted. Given that we actually want - // a partial lowering, we explicitly mark the Toy operations that don't want - // to lower, `toy.print`, as `legal`. `toy.print` will still need its operands - // to be updated though (as we convert from TensorType to MemRefType), so we - // only treat it as `legal` if its operands are legal. - target.addIllegalDialect(); - target.addDynamicallyLegalOp([](toy::PrintOp op) { - return llvm::none_of(op->getOperandTypes(), - [](Type type) { return llvm::isa(type); }); - }); - - // Now that the conversion target has been defined, we just need to provide - // the set of patterns that will lower the Toy operations. - RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - - // With the target and rewrite patterns defined, we can now attempt the - // conversion. The conversion will signal failure if any of our `illegal` - // operations were not converted successfully. - if (failed( - applyPartialConversion(getOperation(), target, std::move(patterns)))) - signalPassFailure(); -} - -/// Create a pass for lowering operations in the `Affine` and `Std` dialects, -/// for a subset of the Toy IR (e.g. matmul). -std::unique_ptr mlir::toy::createLowerToAffinePass() { - return std::make_unique(); -} diff --git a/mlir/example/Ch8/mlir/LowerToLLVM.cpp b/mlir/example/Ch8/mlir/LowerToLLVM.cpp deleted file mode 100644 index 3ad70e7..0000000 --- a/mlir/example/Ch8/mlir/LowerToLLVM.cpp +++ /dev/null @@ -1,240 +0,0 @@ -//====- LowerToLLVM.cpp - Lowering from Toy+Affine+Std to LLVM ------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements full lowering of Toy operations to LLVM MLIR dialect. -// 'toy.print' is lowered to a loop nest that calls `printf` on each element of -// the input array. The file also sets up the ToyToLLVMLoweringPass. This pass -// lowers the combination of Arithmetic + Affine + SCF + Func dialects to the -// LLVM one: -// -// Affine -- -// | -// v -// Arithmetic + Func --> LLVM (Dialect) -// ^ -// | -// 'toy.print' --> Loop (SCF) -- -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" -#include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/TypeID.h" -#include "toy/Dialect.h" -#include "toy/Passes.h" - -#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" -#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" -#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" -#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "llvm/Support/Casting.h" -#include -#include - -using namespace mlir; - -//===----------------------------------------------------------------------===// -// ToyToLLVM RewritePatterns -//===----------------------------------------------------------------------===// - -namespace { -/// Lowers `toy.print` to a loop nest calling `printf` on each of the individual -/// elements of the array. -class PrintOpLowering : public ConversionPattern { -public: - explicit PrintOpLowering(MLIRContext *context) - : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} - - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto *context = rewriter.getContext(); - auto memRefType = llvm::cast((*op->operand_type_begin())); - auto memRefShape = memRefType.getShape(); - auto loc = op->getLoc(); - - ModuleOp parentModule = op->getParentOfType(); - - // Get a symbol reference to the printf function, inserting it if necessary. - auto printfRef = getOrInsertPrintf(rewriter, parentModule); - Value formatSpecifierCst = getOrCreateGlobalString( - loc, rewriter, "frmt_spec", StringRef("%f \0", 4), parentModule); - Value newLineCst = getOrCreateGlobalString( - loc, rewriter, "nl", StringRef("\n\0", 2), parentModule); - - // Create a loop for each of the dimensions within the shape. - SmallVector loopIvs; - for (unsigned i = 0, e = memRefShape.size(); i != e; ++i) { - auto lowerBound = rewriter.create(loc, 0); - auto upperBound = - rewriter.create(loc, memRefShape[i]); - auto step = rewriter.create(loc, 1); - auto loop = - rewriter.create(loc, lowerBound, upperBound, step); - for (Operation &nested : *loop.getBody()) - rewriter.eraseOp(&nested); - loopIvs.push_back(loop.getInductionVar()); - - // Terminate the loop body. - rewriter.setInsertionPointToEnd(loop.getBody()); - - // Insert a newline after each of the inner dimensions of the shape. - if (i != e - 1) - rewriter.create(loc, getPrintfType(context), printfRef, - newLineCst); - rewriter.create(loc); - rewriter.setInsertionPointToStart(loop.getBody()); - } - - // Generate a call to printf for the current element of the loop. - auto printOp = cast(op); - auto elementLoad = - rewriter.create(loc, printOp.getInput(), loopIvs); - rewriter.create( - loc, getPrintfType(context), printfRef, - ArrayRef({formatSpecifierCst, elementLoad})); - - // Notify the rewriter that this operation has been removed. - rewriter.eraseOp(op); - return success(); - } - -private: - /// Create a function declaration for printf, the signature is: - /// * `i32 (i8*, ...)` - static LLVM::LLVMFunctionType getPrintfType(MLIRContext *context) { - auto llvmI32Ty = IntegerType::get(context, 32); - auto llvmPtrTy = LLVM::LLVMPointerType::get(context); - auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy, - /*isVarArg=*/true); - return llvmFnType; - } - - /// Return a symbol reference to the printf function, inserting it into the - /// module if necessary. - static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, - ModuleOp module) { - auto *context = module.getContext(); - if (module.lookupSymbol("printf")) - return SymbolRefAttr::get(context, "printf"); - - // Insert the printf function into the body of the parent module. - PatternRewriter::InsertionGuard insertGuard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create(module.getLoc(), "printf", - getPrintfType(context)); - return SymbolRefAttr::get(context, "printf"); - } - - /// Return a value representing an access into a global string with the given - /// name, creating the string if necessary. - static Value getOrCreateGlobalString(Location loc, OpBuilder &builder, - StringRef name, StringRef value, - ModuleOp module) { - // Create the global at the entry of the module. - LLVM::GlobalOp global; - if (!(global = module.lookupSymbol(name))) { - OpBuilder::InsertionGuard insertGuard(builder); - builder.setInsertionPointToStart(module.getBody()); - auto type = LLVM::LLVMArrayType::get( - IntegerType::get(builder.getContext(), 8), value.size()); - global = builder.create(loc, type, /*isConstant=*/true, - LLVM::Linkage::Internal, name, - builder.getStringAttr(value), - /*alignment=*/0); - } - - // Get the pointer to the first character in the global string. - Value globalPtr = builder.create(loc, global); - Value cst0 = builder.create(loc, builder.getI64Type(), - builder.getIndexAttr(0)); - return builder.create( - loc, LLVM::LLVMPointerType::get(builder.getContext()), global.getType(), - globalPtr, ArrayRef({cst0, cst0})); - } -}; -} // namespace - -//===----------------------------------------------------------------------===// -// ToyToLLVMLoweringPass -//===----------------------------------------------------------------------===// - -namespace { -struct ToyToLLVMLoweringPass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ToyToLLVMLoweringPass) - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - void runOnOperation() final; -}; -} // namespace - -void ToyToLLVMLoweringPass::runOnOperation() { - // The first thing to define is the conversion target. This will define the - // final target for this lowering. For this lowering, we are only targeting - // the LLVM dialect. - LLVMConversionTarget target(getContext()); - target.addLegalOp(); - - // During this lowering, we will also be lowering the MemRef types, that are - // currently being operated on, to a representation in LLVM. To perform this - // conversion we use a TypeConverter as part of the lowering. This converter - // details how one type maps to another. This is necessary now that we will be - // doing more complicated lowerings, involving loop region arguments. - LLVMTypeConverter typeConverter(&getContext()); - - // Now that the conversion target has been defined, we need to provide the - // patterns used for lowering. At this point of the compilation process, we - // have a combination of `toy`, `affine`, and `std` operations. Luckily, there - // are already exists a set of patterns to transform `affine` and `std` - // dialects. These patterns lowering in multiple stages, relying on transitive - // lowerings. Transitive lowering, or A->B->C lowering, is when multiple - // patterns must be applied to fully transform an illegal operation into a - // set of legal ones. - RewritePatternSet patterns(&getContext()); - populateAffineToStdConversionPatterns(patterns); - populateSCFToControlFlowConversionPatterns(patterns); - mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); - populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns); - cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); - populateFuncToLLVMConversionPatterns(typeConverter, patterns); - - // The only remaining operation to lower from the `toy` dialect, is the - // PrintOp. - patterns.add(&getContext()); - - // We want to completely lower to LLVM, so we use a `FullConversion`. This - // ensures that only legal operations will remain after the conversion. - auto module = getOperation(); - if (failed(applyFullConversion(module, target, std::move(patterns)))) - signalPassFailure(); -} - -/// Create a pass for lowering operations the remaining `Toy` operations, as -/// well as `Affine` and `Std`, to the LLVM dialect for codegen. -std::unique_ptr mlir::toy::createLowerToLLVMPass() { - return std::make_unique(); -} diff --git a/mlir/example/Ch8/mlir/MLIRGen.cpp b/mlir/example/Ch8/mlir/MLIRGen.cpp deleted file mode 100644 index 46dc8e0..0000000 --- a/mlir/example/Ch8/mlir/MLIRGen.cpp +++ /dev/null @@ -1,699 +0,0 @@ -//===- MLIRGen.cpp - MLIR Generation from a Toy AST -----------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a simple IR generation targeting MLIR from a Module AST -// for the Toy language. -// -//===----------------------------------------------------------------------===// - -#include "toy/MLIRGen.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/Value.h" -#include "toy/AST.h" -#include "toy/Dialect.h" - -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Verifier.h" -#include "toy/Lexer.h" - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ScopedHashTable.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Twine.h" -#include "llvm/Support/ErrorHandling.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace mlir::toy; -using namespace toy; - -using llvm::ArrayRef; -using llvm::cast; -using llvm::dyn_cast; -using llvm::isa; -using llvm::ScopedHashTableScope; -using llvm::SmallVector; -using llvm::StringRef; -using llvm::Twine; - -namespace { - -/// Implementation of a simple MLIR emission from the Toy AST. -/// -/// This will emit operations that are specific to the Toy language, preserving -/// the semantics of the language and (hopefully) allow to perform accurate -/// analysis and transformation based on these high level semantics. -class MLIRGenImpl { -public: - MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} - - /// Public API: convert the AST for a Toy module (source file) to an MLIR - /// Module operation. - mlir::ModuleOp mlirGen(ModuleAST &moduleAST) { - // We create an empty MLIR module and codegen functions one at a time and - // add them to the module. - theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); - - for (auto &record : moduleAST) { - if (FunctionAST *funcAST = llvm::dyn_cast(record.get())) { - mlir::toy::FuncOp func = mlirGen(*funcAST); - if (!func) - return nullptr; - functionMap.insert({func.getName(), func}); - } else if (StructAST *str = llvm::dyn_cast(record.get())) { - if (failed(mlirGen(*str))) - return nullptr; - } else { - llvm_unreachable("unknown record type"); - } - } - - // Verify the module after we have finished constructing it, this will check - // the structural properties of the IR and invoke any specific verifiers we - // have on the Toy operations. - if (failed(mlir::verify(theModule))) { - theModule.emitError("module verification error"); - return nullptr; - } - - return theModule; - } - -private: - /// A "module" matches a Toy source file: containing a list of functions. - mlir::ModuleOp theModule; - - /// The builder is a helper class to create IR inside a function. The builder - /// is stateful, in particular it keeps an "insertion point": this is where - /// the next operations will be introduced. - mlir::OpBuilder builder; - - /// The symbol table maps a variable name to a value in the current scope. - /// Entering a function creates a new scope, and the function arguments are - /// added to the mapping. When the processing of a function is terminated, the - /// scope is destroyed and the mappings created in this scope are dropped. - llvm::ScopedHashTable> - symbolTable; - using SymbolTableScopeT = - llvm::ScopedHashTableScope>; - - /// A mapping for the functions that have been code generated to MLIR. - llvm::StringMap functionMap; - - /// A mapping for named struct types to the underlying MLIR type and the - /// original AST node. - llvm::StringMap> structMap; - - /// Helper conversion for a Toy AST location to an MLIR location. - mlir::Location loc(const Location &loc) { - return mlir::FileLineColLoc::get(builder.getStringAttr(*loc.file), loc.line, - loc.col); - } - - /// Declare a variable in the current scope, return success if the variable - /// wasn't declared yet. - llvm::LogicalResult declare(VarDeclExprAST &var, mlir::Value value) { - if (symbolTable.count(var.getName())) - return mlir::failure(); - symbolTable.insert(var.getName(), {value, &var}); - return mlir::success(); - } - - /// Create an MLIR type for the given struct. - llvm::LogicalResult mlirGen(StructAST &str) { - if (structMap.count(str.getName())) - return emitError(loc(str.loc())) << "error: struct type with name `" - << str.getName() << "' already exists"; - - auto variables = str.getVariables(); - std::vector elementTypes; - elementTypes.reserve(variables.size()); - for (auto &variable : variables) { - if (variable->getInitVal()) - return emitError(loc(variable->loc())) - << "error: variables within a struct definition must not have " - "initializers"; - if (!variable->getType().shape.empty()) - return emitError(loc(variable->loc())) - << "error: variables within a struct definition must not have " - "initializers"; - - mlir::Type type = getType(variable->getType(), variable->loc()); - if (!type) - return mlir::failure(); - elementTypes.push_back(type); - } - - structMap.try_emplace(str.getName(), StructType::get(elementTypes), &str); - return mlir::success(); - } - - /// Create the prototype for an MLIR function with as many arguments as the - /// provided Toy AST prototype. - mlir::toy::FuncOp mlirGen(PrototypeAST &proto) { - auto location = loc(proto.loc()); - - // This is a generic function, the return type will be inferred later. - llvm::SmallVector argTypes; - argTypes.reserve(proto.getArgs().size()); - for (auto &arg : proto.getArgs()) { - mlir::Type type = getType(arg->getType(), arg->loc()); - if (!type) - return nullptr; - argTypes.push_back(type); - } - auto funcType = builder.getFunctionType(argTypes, std::nullopt); - return builder.create(location, proto.getName(), - funcType); - } - - /// Emit a new function and add it to the MLIR module. - mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) { - // Create a scope in the symbol table to hold variable declarations. - SymbolTableScopeT varScope(symbolTable); - - // Create an MLIR function for the given prototype. - builder.setInsertionPointToEnd(theModule.getBody()); - mlir::toy::FuncOp function = mlirGen(*funcAST.getProto()); - if (!function) - return nullptr; - - // Let's start the body of the function now! - mlir::Block &entryBlock = function.front(); - auto protoArgs = funcAST.getProto()->getArgs(); - - // Declare all the function arguments in the symbol table. - for (const auto nameValue : - llvm::zip(protoArgs, entryBlock.getArguments())) { - if (failed(declare(*std::get<0>(nameValue), std::get<1>(nameValue)))) - return nullptr; - } - - // Set the insertion point in the builder to the beginning of the function - // body, it will be used throughout the codegen to create operations in this - // function. - builder.setInsertionPointToStart(&entryBlock); - - // Emit the body of the function. - if (mlir::failed(mlirGen(*funcAST.getBody()))) { - function.erase(); - return nullptr; - } - - // Implicitly return void if no return statement was emitted. - // FIXME: we may fix the parser instead to always return the last expression - // (this would possibly help the REPL case later) - ReturnOp returnOp; - if (!entryBlock.empty()) - returnOp = dyn_cast(entryBlock.back()); - if (!returnOp) { - builder.create(loc(funcAST.getProto()->loc())); - } else if (returnOp.hasOperand()) { - // Otherwise, if this return operation has an operand then add a result to - // the function. - function.setType( - builder.getFunctionType(function.getFunctionType().getInputs(), - *returnOp.operand_type_begin())); - } - - // If this function isn't main, then set the visibility to private. - if (funcAST.getProto()->getName() != "main") - function.setPrivate(); - - return function; - } - - /// Return the struct type that is the result of the given expression, or null - /// if it cannot be inferred. - StructAST *getStructFor(ExprAST *expr) { - llvm::StringRef structName; - if (auto *decl = llvm::dyn_cast(expr)) { - auto varIt = symbolTable.lookup(decl->getName()); - if (!varIt.first) - return nullptr; - structName = varIt.second->getType().name; - } else if (auto *access = llvm::dyn_cast(expr)) { - if (access->getOp() != '.') - return nullptr; - // The name being accessed should be in the RHS. - auto *name = llvm::dyn_cast(access->getRHS()); - if (!name) - return nullptr; - StructAST *parentStruct = getStructFor(access->getLHS()); - if (!parentStruct) - return nullptr; - - // Get the element within the struct corresponding to the name. - VarDeclExprAST *decl = nullptr; - for (auto &var : parentStruct->getVariables()) { - if (var->getName() == name->getName()) { - decl = var.get(); - break; - } - } - if (!decl) - return nullptr; - structName = decl->getType().name; - } - if (structName.empty()) - return nullptr; - - // If the struct name was valid, check for an entry in the struct map. - auto structIt = structMap.find(structName); - if (structIt == structMap.end()) - return nullptr; - return structIt->second.second; - } - - /// Return the numeric member index of the given struct access expression. - std::optional getMemberIndex(BinaryExprAST &accessOp) { - assert(accessOp.getOp() == '.' && "expected access operation"); - - // Lookup the struct node for the LHS. - StructAST *structAST = getStructFor(accessOp.getLHS()); - if (!structAST) - return std::nullopt; - - // Get the name from the RHS. - VariableExprAST *name = llvm::dyn_cast(accessOp.getRHS()); - if (!name) - return std::nullopt; - - auto structVars = structAST->getVariables(); - const auto *it = llvm::find_if(structVars, [&](auto &var) { - return var->getName() == name->getName(); - }); - if (it == structVars.end()) - return std::nullopt; - return it - structVars.begin(); - } - - /// Emit a binary operation - mlir::Value mlirGen(BinaryExprAST &binop) { - // First emit the operations for each side of the operation before emitting - // the operation itself. For example if the expression is `a + foo(a)` - // 1) First it will visiting the LHS, which will return a reference to the - // value holding `a`. This value should have been emitted at declaration - // time and registered in the symbol table, so nothing would be - // codegen'd. If the value is not in the symbol table, an error has been - // emitted and nullptr is returned. - // 2) Then the RHS is visited (recursively) and a call to `foo` is emitted - // and the result value is returned. If an error occurs we get a nullptr - // and propagate. - // - mlir::Value lhs = mlirGen(*binop.getLHS()); - if (!lhs) - return nullptr; - auto location = loc(binop.loc()); - - // If this is an access operation, handle it immediately. - if (binop.getOp() == '.') { - std::optional accessIndex = getMemberIndex(binop); - if (!accessIndex) { - emitError(location, "invalid access into struct expression"); - return nullptr; - } - return builder.create(location, lhs, *accessIndex); - } - - // Otherwise, this is a normal binary op. - mlir::Value rhs = mlirGen(*binop.getRHS()); - if (!rhs) - return nullptr; - - // Derive the operation name from the binary operator. At the moment we only - // support '+' and '*'. - switch (binop.getOp()) { - case '+': - return builder.create(location, lhs, rhs); - case '*': - return builder.create(location, lhs, rhs); - } - - emitError(location, "invalid binary operator '") << binop.getOp() << "'"; - return nullptr; - } - - /// This is a reference to a variable in an expression. The variable is - /// expected to have been declared and so should have a value in the symbol - /// table, otherwise emit an error and return nullptr. - mlir::Value mlirGen(VariableExprAST &expr) { - if (auto variable = symbolTable.lookup(expr.getName()).first) - return variable; - - emitError(loc(expr.loc()), "error: unknown variable '") - << expr.getName() << "'"; - return nullptr; - } - - /// Emit a return operation. This will return failure if any generation fails. - llvm::LogicalResult mlirGen(ReturnExprAST &ret) { - auto location = loc(ret.loc()); - - // 'return' takes an optional expression, handle that case here. - mlir::Value expr = nullptr; - if (ret.getExpr().has_value()) { - if (!(expr = mlirGen(**ret.getExpr()))) - return mlir::failure(); - } - - // Otherwise, this return operation has zero operands. - builder.create(location, - expr ? ArrayRef(expr) : ArrayRef()); - return mlir::success(); - } - - /// Emit a constant for a literal/constant array. It will be emitted as a - /// flattened array of data in an Attribute attached to a `toy.constant` - /// operation. See documentation on [Attributes](LangRef.md#attributes) for - /// more details. Here is an excerpt: - /// - /// Attributes are the mechanism for specifying constant data in MLIR in - /// places where a variable is never allowed [...]. They consist of a name - /// and a concrete attribute value. The set of expected attributes, their - /// structure, and their interpretation are all contextually dependent on - /// what they are attached to. - /// - /// Example, the source level statement: - /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; - /// will be converted to: - /// %0 = "toy.constant"() {value: dense, - /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], - /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> - /// - mlir::DenseElementsAttr getConstantAttr(LiteralExprAST &lit) { - // The attribute is a vector with a floating point value per element - // (number) in the array, see `collectData()` below for more details. - std::vector data; - data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, - std::multiplies())); - collectData(lit, data); - - // The type of this attribute is tensor of 64-bit floating-point with the - // shape of the literal. - mlir::Type elementType = builder.getF64Type(); - auto dataType = mlir::RankedTensorType::get(lit.getDims(), elementType); - - // This is the actual attribute that holds the list of values for this - // tensor literal. - return mlir::DenseElementsAttr::get(dataType, llvm::ArrayRef(data)); - } - mlir::DenseElementsAttr getConstantAttr(NumberExprAST &lit) { - // The type of this attribute is tensor of 64-bit floating-point with no - // shape. - mlir::Type elementType = builder.getF64Type(); - auto dataType = mlir::RankedTensorType::get({}, elementType); - - // This is the actual attribute that holds the list of values for this - // tensor literal. - return mlir::DenseElementsAttr::get(dataType, - llvm::ArrayRef(lit.getValue())); - } - /// Emit a constant for a struct literal. It will be emitted as an array of - /// other literals in an Attribute attached to a `toy.struct_constant` - /// operation. This function returns the generated constant, along with the - /// corresponding struct type. - std::pair - getConstantAttr(StructLiteralExprAST &lit) { - std::vector attrElements; - std::vector typeElements; - - for (auto &var : lit.getValues()) { - if (auto *number = llvm::dyn_cast(var.get())) { - attrElements.push_back(getConstantAttr(*number)); - typeElements.push_back(getType(std::nullopt)); - } else if (auto *lit = llvm::dyn_cast(var.get())) { - attrElements.push_back(getConstantAttr(*lit)); - typeElements.push_back(getType(std::nullopt)); - } else { - auto *structLit = llvm::cast(var.get()); - auto attrTypePair = getConstantAttr(*structLit); - attrElements.push_back(attrTypePair.first); - typeElements.push_back(attrTypePair.second); - } - } - mlir::ArrayAttr dataAttr = builder.getArrayAttr(attrElements); - mlir::Type dataType = StructType::get(typeElements); - return std::make_pair(dataAttr, dataType); - } - - /// Emit an array literal. - mlir::Value mlirGen(LiteralExprAST &lit) { - mlir::Type type = getType(lit.getDims()); - mlir::DenseElementsAttr dataAttribute = getConstantAttr(lit); - - // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` - // method. - return builder.create(loc(lit.loc()), type, dataAttribute); - } - - /// Emit a struct literal. It will be emitted as an array of - /// other literals in an Attribute attached to a `toy.struct_constant` - /// operation. - mlir::Value mlirGen(StructLiteralExprAST &lit) { - mlir::ArrayAttr dataAttr; - mlir::Type dataType; - std::tie(dataAttr, dataType) = getConstantAttr(lit); - - // Build the MLIR op `toy.struct_constant`. This invokes the - // `StructConstantOp::build` method. - return builder.create(loc(lit.loc()), dataType, dataAttr); - } - - /// Recursive helper function to accumulate the data that compose an array - /// literal. It flattens the nested structure in the supplied vector. For - /// example with this array: - /// [[1, 2], [3, 4]] - /// we will generate: - /// [ 1, 2, 3, 4 ] - /// Individual numbers are represented as doubles. - /// Attributes are the way MLIR attaches constant to operations. - void collectData(ExprAST &expr, std::vector &data) { - if (auto *lit = dyn_cast(&expr)) { - for (auto &value : lit->getValues()) - collectData(*value, data); - return; - } - - assert(isa(expr) && "expected literal or number expr"); - data.push_back(cast(expr).getValue()); - } - - /// Emit a call expression. It emits specific operations for the `transpose` - /// builtin. Other identifiers are assumed to be user-defined functions. - mlir::Value mlirGen(CallExprAST &call) { - llvm::StringRef callee = call.getCallee(); - auto location = loc(call.loc()); - - // Codegen the operands first. - SmallVector operands; - for (auto &expr : call.getArgs()) { - auto arg = mlirGen(*expr); - if (!arg) - return nullptr; - operands.push_back(arg); - } - - // Builtin calls have their custom operation, meaning this is a - // straightforward emission. - if (callee == "transpose") { - if (call.getArgs().size() != 1) { - emitError(location, "MLIR codegen encountered an error: toy.transpose " - "does not accept multiple arguments"); - return nullptr; - } - return builder.create(location, operands[0]); - } - - if (callee == "matmul") { - if (call.getArgs().size() != 2) { - emitError(location, "MLIR codegen encountered an error: toy.matmul " - "expected 2 arguments"); - } - return builder.create(location, operands[0], operands[1]); - } - - // Otherwise this is a call to a user-defined function. Calls to - // user-defined functions are mapped to a custom call that takes the callee - // name as an attribute. - auto calledFuncIt = functionMap.find(callee); - if (calledFuncIt == functionMap.end()) { - emitError(location) << "no defined function found for '" << callee << "'"; - return nullptr; - } - mlir::toy::FuncOp calledFunc = calledFuncIt->second; - return builder.create( - location, calledFunc.getFunctionType().getResult(0), - mlir::SymbolRefAttr::get(builder.getContext(), callee), operands); - } - - /// Emit a print expression. It emits specific operations for two builtins: - /// transpose(x) and print(x). - llvm::LogicalResult mlirGen(PrintExprAST &call) { - auto arg = mlirGen(*call.getArg()); - if (!arg) - return mlir::failure(); - - builder.create(loc(call.loc()), arg); - return mlir::success(); - } - - /// Emit a constant for a single number (FIXME: semantic? broadcast?) - mlir::Value mlirGen(NumberExprAST &num) { - return builder.create(loc(num.loc()), num.getValue()); - } - - /// Dispatch codegen for the right expression subclass using RTTI. - mlir::Value mlirGen(ExprAST &expr) { - switch (expr.getKind()) { - case toy::ExprAST::Expr_BinOp: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Var: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Literal: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_StructLiteral: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Call: - return mlirGen(cast(expr)); - case toy::ExprAST::Expr_Num: - return mlirGen(cast(expr)); - default: - emitError(loc(expr.loc())) - << "MLIR codegen encountered an unhandled expr kind '" - << Twine(expr.getKind()) << "'"; - return nullptr; - } - } - - /// Handle a variable declaration, we'll codegen the expression that forms the - /// initializer and record the value in the symbol table before returning it. - /// Future expressions will be able to reference this variable through symbol - /// table lookup. - mlir::Value mlirGen(VarDeclExprAST &vardecl) { - auto *init = vardecl.getInitVal(); - if (!init) { - emitError(loc(vardecl.loc()), - "missing initializer in variable declaration"); - return nullptr; - } - - mlir::Value value = mlirGen(*init); - if (!value) - return nullptr; - - // Handle the case where we are initializing a struct value. - VarType varType = vardecl.getType(); - if (!varType.name.empty()) { - // Check that the initializer type is the same as the variable - // declaration. - mlir::Type type = getType(varType, vardecl.loc()); - if (!type) - return nullptr; - if (type != value.getType()) { - emitError(loc(vardecl.loc())) - << "struct type of initializer is different than the variable " - "declaration. Got " - << value.getType() << ", but expected " << type; - return nullptr; - } - - // Otherwise, we have the initializer value, but in case the variable was - // declared with specific shape, we emit a "reshape" operation. It will - // get optimized out later as needed. - } else if (!varType.shape.empty()) { - value = builder.create(loc(vardecl.loc()), - getType(varType.shape), value); - } - - // Register the value in the symbol table. - if (failed(declare(vardecl, value))) - return nullptr; - return value; - } - - /// Codegen a list of expression, return failure if one of them hit an error. - llvm::LogicalResult mlirGen(ExprASTList &blockAST) { - SymbolTableScopeT varScope(symbolTable); - for (auto &expr : blockAST) { - // Specific handling for variable declarations, return statement, and - // print. These can only appear in block list and not in nested - // expressions. - if (auto *vardecl = dyn_cast(expr.get())) { - if (!mlirGen(*vardecl)) - return mlir::failure(); - continue; - } - if (auto *ret = dyn_cast(expr.get())) - return mlirGen(*ret); - if (auto *print = dyn_cast(expr.get())) { - if (mlir::failed(mlirGen(*print))) - return mlir::success(); - continue; - } - - // Generic expression dispatch codegen. - if (!mlirGen(*expr)) - return mlir::failure(); - } - return mlir::success(); - } - - /// Build a tensor type from a list of shape dimensions. - mlir::Type getType(ArrayRef shape) { - // If the shape is empty, then this type is unranked. - if (shape.empty()) - return mlir::UnrankedTensorType::get(builder.getF64Type()); - - // Otherwise, we use the given shape. - return mlir::RankedTensorType::get(shape, builder.getF64Type()); - } - - /// Build an MLIR type from a Toy AST variable type (forward to the generic - /// getType above for non-struct types). - mlir::Type getType(const VarType &type, const Location &location) { - if (!type.name.empty()) { - auto it = structMap.find(type.name); - if (it == structMap.end()) { - emitError(loc(location)) - << "error: unknown struct type '" << type.name << "'"; - return nullptr; - } - return it->second.first; - } - - return getType(type.shape); - } -}; - -} // namespace - -namespace toy { - -// The public API for codegen. -mlir::OwningOpRef mlirGen(mlir::MLIRContext &context, - ModuleAST &moduleAST) { - return MLIRGenImpl(context).mlirGen(moduleAST); -} - -} // namespace toy diff --git a/mlir/example/Ch8/mlir/ShapeInferencePass.cpp b/mlir/example/Ch8/mlir/ShapeInferencePass.cpp deleted file mode 100644 index a9e995e..0000000 --- a/mlir/example/Ch8/mlir/ShapeInferencePass.cpp +++ /dev/null @@ -1,122 +0,0 @@ -//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a Function level pass performing interprocedural -// propagation of array shapes through function specialization. -// -//===----------------------------------------------------------------------===// - -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Types.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/TypeID.h" -#include "toy/Dialect.h" -#include "toy/Passes.h" -#include "toy/ShapeInferenceInterface.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" -#include - -#define DEBUG_TYPE "shape-inference" - -using namespace mlir; -using namespace toy; - -/// Include the auto-generated definitions for the shape inference interfaces. -#include "toy/ShapeInferenceOpInterfaces.cpp.inc" - -namespace { -/// The ShapeInferencePass is a pass that performs intra-procedural -/// shape inference. -/// -/// Algorithm: -/// -/// 1) Build a worklist containing all the operations that return a -/// dynamically shaped tensor: these are the operations that need shape -/// inference. -/// 2) Iterate on the worklist: -/// a) find an operation to process: the next ready operation in the -/// worklist has all of its arguments non-generic, -/// b) if no operation is found, break out of the loop, -/// c) remove the operation from the worklist, -/// d) infer the shape of its output from the argument types. -/// 3) If the worklist is empty, the algorithm succeeded. -/// -struct ShapeInferencePass - : public mlir::PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ShapeInferencePass) - - void runOnOperation() override { - auto f = getOperation(); - - // Populate the worklist with the operations that need shape inference: - // these are operations that return a dynamic shape. - llvm::SmallPtrSet opWorklist; - f.walk([&](mlir::Operation *op) { - if (returnsDynamicShape(op)) - opWorklist.insert(op); - }); - - // Iterate on the operations in the worklist until all operations have been - // inferred or no change happened (fix point). - while (!opWorklist.empty()) { - // Find the next operation ready for inference, that is an operation - // with all operands already resolved (non-generic). - auto nextop = llvm::find_if(opWorklist, allOperandsInferred); - if (nextop == opWorklist.end()) - break; - - Operation *op = *nextop; - opWorklist.erase(op); - - // Ask the operation to infer its output shapes. - LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n"); - if (auto shapeOp = dyn_cast(op)) { - shapeOp.inferShapes(); - } else { - op->emitError("unable to infer shape of operation without shape " - "inference interface"); - return signalPassFailure(); - } - } - - // If the operation worklist isn't empty, this indicates a failure. - if (!opWorklist.empty()) { - f.emitError("Shape inference failed, ") - << opWorklist.size() << " operations couldn't be inferred\n"; - signalPassFailure(); - } - } - - /// A utility method that returns if the given operation has all of its - /// operands inferred. - static bool allOperandsInferred(Operation *op) { - return llvm::all_of(op->getOperandTypes(), [](Type operandType) { - return llvm::isa(operandType); - }); - } - - /// A utility method that returns if the given operation has a dynamically - /// shaped result. - static bool returnsDynamicShape(Operation *op) { - return llvm::any_of(op->getResultTypes(), [](Type resultType) { - return !llvm::isa(resultType); - }); - } -}; -} // namespace - -/// Create a Shape Inference pass. -std::unique_ptr mlir::toy::createShapeInferencePass() { - return std::make_unique(); -} diff --git a/mlir/example/Ch8/mlir/ToyCombine.cpp b/mlir/example/Ch8/mlir/ToyCombine.cpp deleted file mode 100644 index 1d8cf74..0000000 --- a/mlir/example/Ch8/mlir/ToyCombine.cpp +++ /dev/null @@ -1,89 +0,0 @@ -//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a set of simple combiners for optimizing operations in -// the Toy dialect. -// -//===----------------------------------------------------------------------===// - -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Value.h" -#include "toy/Dialect.h" -#include "llvm/Support/Casting.h" -#include -using namespace mlir; -using namespace toy; - -namespace { -/// Include the patterns defined in the Declarative Rewrite framework. -#include "ToyCombine.inc" -} // namespace - -/// Fold constants. -OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } - -/// Fold struct constants. -OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } - -/// Fold simple struct access operations that access into a constant. -OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) { - auto structAttr = - llvm::dyn_cast_if_present(adaptor.getInput()); - if (!structAttr) - return nullptr; - - size_t elementIndex = getIndex(); - return structAttr[elementIndex]; -} - -/// This is an example of a c++ rewrite pattern for the TransposeOp. It -/// optimizes the following scenario: transpose(transpose(x)) -> x -struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { - /// We register this pattern to match every toy.transpose in the IR. - /// The "benefit" is used by the framework to order the patterns and process - /// them in order of profitability. - SimplifyRedundantTranspose(mlir::MLIRContext *context) - : OpRewritePattern(context, /*benefit=*/1) {} - - /// This method attempts to match a pattern and rewrite it. The rewriter - /// argument is the orchestrator of the sequence of rewrites. The pattern is - /// expected to interact with it to perform any changes to the IR from here. - llvm::LogicalResult - matchAndRewrite(TransposeOp op, - mlir::PatternRewriter &rewriter) const override { - // Look through the input of the current transpose. - mlir::Value transposeInput = op.getOperand(); - TransposeOp transposeInputOp = transposeInput.getDefiningOp(); - - // Input defined by another transpose? If not, no match. - if (!transposeInputOp) - return failure(); - - // Otherwise, we have a redundant transpose. Use the rewriter. - rewriter.replaceOp(op, {transposeInputOp.getOperand()}); - return success(); - } -}; - -/// Register our patterns as "canonicalization" patterns on the TransposeOp so -/// that they can be picked up by the Canonicalization framework. -void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -/// Register our patterns as "canonicalization" patterns on the ReshapeOp so -/// that they can be picked up by the Canonicalization framework. -void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} diff --git a/mlir/example/Ch8/mlir/ToyCombine.td b/mlir/example/Ch8/mlir/ToyCombine.td deleted file mode 100644 index 11d7831..0000000 --- a/mlir/example/Ch8/mlir/ToyCombine.td +++ /dev/null @@ -1,63 +0,0 @@ -//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Defines language-specific pattern match optimizations for Toy using -// Declarative Rewrite Rules (DRR) specified using TableGen records. -// -//===----------------------------------------------------------------------===// - -#ifndef TOY_COMBINE -#define TOY_COMBINE - -include "mlir/IR/PatternBase.td" -include "toy/Ops.td" - -/// Note: The DRR definition used for defining patterns is shown below: -/// -/// class Pattern< -/// dag sourcePattern, list resultPatterns, -/// list additionalConstraints = [], -/// dag benefitsAdded = (addBenefit 0) -/// >; - -//===----------------------------------------------------------------------===// -// Basic Pattern-Match and Rewrite -//===----------------------------------------------------------------------===// - -// Reshape(Reshape(x)) = Reshape(x) -def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), - (ReshapeOp $arg)>; - -//===----------------------------------------------------------------------===// -// Pattern-Match and Rewrite using Native Code Call -//===----------------------------------------------------------------------===// - -// Native Code Calls may be used for more complex transformations using inline -// C++ and C++ helper functions. - -// Reshape(Constant(x)) = x' -def ReshapeConstant : - NativeCodeCall<"$0.reshape(::llvm::cast($1.getType()))">; -def FoldConstantReshapeOptPattern : Pat< - (ReshapeOp:$res (ConstantOp $arg)), - (ConstantOp (ReshapeConstant $arg, $res))>; - -//===----------------------------------------------------------------------===// -// Pattern-Match and Rewrite with Constraints -//===----------------------------------------------------------------------===// - -// DRR allows for constraint checking when the transformation is conditional -// on operand properties. - -// Reshape(x) = x, where input and output shapes are identical -def TypesAreIdentical : Constraint>; -def RedundantReshapeOptPattern : Pat< - (ReshapeOp:$res $arg), (replaceWithValue $arg), - [(TypesAreIdentical $res, $arg)]>; - -#endif // TOY_COMBINE diff --git a/mlir/example/Ch8/parser/AST.cpp b/mlir/example/Ch8/parser/AST.cpp deleted file mode 100644 index e38a743..0000000 --- a/mlir/example/Ch8/parser/AST.cpp +++ /dev/null @@ -1,274 +0,0 @@ -//===- AST.cpp - Helper for printing out the Toy AST ----------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements the AST dump for the Toy language. -// -//===----------------------------------------------------------------------===// - -#include "toy/AST.h" - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/Twine.h" -#include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/raw_ostream.h" -#include - -using namespace toy; - -namespace { - -// RAII helper to manage increasing/decreasing the indentation as we traverse -// the AST -struct Indent { - Indent(int &level) : level(level) { ++level; } - ~Indent() { --level; } - int &level; -}; - -/// Helper class that implement the AST tree traversal and print the nodes along -/// the way. The only data member is the current indentation level. -class ASTDumper { -public: - void dump(ModuleAST *node); - -private: - void dump(const VarType &type); - void dump(VarDeclExprAST *varDecl); - void dump(ExprAST *expr); - void dump(ExprASTList *exprList); - void dump(NumberExprAST *num); - void dump(LiteralExprAST *node); - void dump(StructLiteralExprAST *node); - void dump(VariableExprAST *node); - void dump(ReturnExprAST *node); - void dump(BinaryExprAST *node); - void dump(CallExprAST *node); - void dump(PrintExprAST *node); - void dump(PrototypeAST *node); - void dump(FunctionAST *node); - void dump(StructAST *node); - - // Actually print spaces matching the current indentation level - void indent() { - for (int i = 0; i < curIndent; i++) - llvm::errs() << " "; - } - int curIndent = 0; -}; - -} // namespace - -/// Return a formatted string for the location of any node -template -static std::string loc(T *node) { - const auto &loc = node->loc(); - return (llvm::Twine("@") + *loc.file + ":" + llvm::Twine(loc.line) + ":" + - llvm::Twine(loc.col)) - .str(); -} - -// Helper Macro to bump the indentation level and print the leading spaces for -// the current indentations -#define INDENT() \ - Indent level_(curIndent); \ - indent(); - -/// Dispatch to a generic expressions to the appropriate subclass using RTTI -void ASTDumper::dump(ExprAST *expr) { - llvm::TypeSwitch(expr) - .Case([&](auto *node) { this->dump(node); }) - .Default([&](ExprAST *) { - // No match, fallback to a generic message - INDENT(); - llvm::errs() << "getKind() << ">\n"; - }); -} - -/// A variable declaration is printing the variable name, the type, and then -/// recurse in the initializer value. -void ASTDumper::dump(VarDeclExprAST *varDecl) { - INDENT(); - llvm::errs() << "VarDecl " << varDecl->getName(); - dump(varDecl->getType()); - llvm::errs() << " " << loc(varDecl) << "\n"; - if (auto *initVal = varDecl->getInitVal()) - dump(initVal); -} - -/// A "block", or a list of expression -void ASTDumper::dump(ExprASTList *exprList) { - INDENT(); - llvm::errs() << "Block {\n"; - for (auto &expr : *exprList) - dump(expr.get()); - indent(); - llvm::errs() << "} // Block\n"; -} - -/// A literal number, just print the value. -void ASTDumper::dump(NumberExprAST *num) { - INDENT(); - llvm::errs() << num->getValue() << " " << loc(num) << "\n"; -} - -/// Helper to print recursively a literal. This handles nested array like: -/// [ [ 1, 2 ], [ 3, 4 ] ] -/// We print out such array with the dimensions spelled out at every level: -/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] -void printLitHelper(ExprAST *litOrNum) { - // Inside a literal expression we can have either a number or another literal - if (auto *num = llvm::dyn_cast(litOrNum)) { - llvm::errs() << num->getValue(); - return; - } - auto *literal = llvm::cast(litOrNum); - - // Print the dimension for this literal first - llvm::errs() << "<"; - llvm::interleaveComma(literal->getDims(), llvm::errs()); - llvm::errs() << ">"; - - // Now print the content, recursing on every element of the list - llvm::errs() << "[ "; - llvm::interleaveComma(literal->getValues(), llvm::errs(), - [&](auto &elt) { printLitHelper(elt.get()); }); - llvm::errs() << "]"; -} - -/// Print a literal, see the recursive helper above for the implementation. -void ASTDumper::dump(LiteralExprAST *node) { - INDENT(); - llvm::errs() << "Literal: "; - printLitHelper(node); - llvm::errs() << " " << loc(node) << "\n"; -} - -/// Print a struct literal. -void ASTDumper::dump(StructLiteralExprAST *node) { - INDENT(); - llvm::errs() << "Struct Literal: "; - for (auto &value : node->getValues()) - dump(value.get()); - indent(); - llvm::errs() << " " << loc(node) << "\n"; -} - -/// Print a variable reference (just a name). -void ASTDumper::dump(VariableExprAST *node) { - INDENT(); - llvm::errs() << "var: " << node->getName() << " " << loc(node) << "\n"; -} - -/// Return statement print the return and its (optional) argument. -void ASTDumper::dump(ReturnExprAST *node) { - INDENT(); - llvm::errs() << "Return\n"; - if (node->getExpr().has_value()) - return dump(*node->getExpr()); - { - INDENT(); - llvm::errs() << "(void)\n"; - } -} - -/// Print a binary operation, first the operator, then recurse into LHS and RHS. -void ASTDumper::dump(BinaryExprAST *node) { - INDENT(); - llvm::errs() << "BinOp: " << node->getOp() << " " << loc(node) << "\n"; - dump(node->getLHS()); - dump(node->getRHS()); -} - -/// Print a call expression, first the callee name and the list of args by -/// recursing into each individual argument. -void ASTDumper::dump(CallExprAST *node) { - INDENT(); - llvm::errs() << "Call '" << node->getCallee() << "' [ " << loc(node) << "\n"; - for (auto &arg : node->getArgs()) - dump(arg.get()); - indent(); - llvm::errs() << "]\n"; -} - -/// Print a builtin print call, first the builtin name and then the argument. -void ASTDumper::dump(PrintExprAST *node) { - INDENT(); - llvm::errs() << "Print [ " << loc(node) << "\n"; - dump(node->getArg()); - indent(); - llvm::errs() << "]\n"; -} - -/// Print type: only the shape is printed in between '<' and '>' -void ASTDumper::dump(const VarType &type) { - llvm::errs() << "<"; - if (!type.name.empty()) - llvm::errs() << type.name; - else - llvm::interleaveComma(type.shape, llvm::errs()); - llvm::errs() << ">"; -} - -/// Print a function prototype, first the function name, and then the list of -/// parameters names. -void ASTDumper::dump(PrototypeAST *node) { - INDENT(); - llvm::errs() << "Proto '" << node->getName() << "' " << loc(node) << "\n"; - indent(); - llvm::errs() << "Params: ["; - llvm::interleaveComma(node->getArgs(), llvm::errs(), - [](auto &arg) { llvm::errs() << arg->getName(); }); - llvm::errs() << "]\n"; -} - -/// Print a function, first the prototype and then the body. -void ASTDumper::dump(FunctionAST *node) { - INDENT(); - llvm::errs() << "Function \n"; - dump(node->getProto()); - dump(node->getBody()); -} - -/// Print a struct. -void ASTDumper::dump(StructAST *node) { - INDENT(); - llvm::errs() << "Struct: " << node->getName() << " " << loc(node) << "\n"; - - { - INDENT(); - llvm::errs() << "Variables: [\n"; - for (auto &variable : node->getVariables()) - dump(variable.get()); - indent(); - llvm::errs() << "]\n"; - } -} - -/// Print a module, actually loop over the functions and print them in sequence. -void ASTDumper::dump(ModuleAST *node) { - INDENT(); - llvm::errs() << "Module:\n"; - for (auto &record : *node) { - if (FunctionAST *function = llvm::dyn_cast(record.get())) - dump(function); - else if (StructAST *str = llvm::dyn_cast(record.get())) - dump(str); - else - llvm::errs() << "getKind() << ">\n"; - } -} - -namespace toy { - -// Public API -void dump(ModuleAST &module) { ASTDumper().dump(&module); } - -} // namespace toy diff --git a/mlir/example/Ch8/struct-codegen.toy b/mlir/example/Ch8/struct-codegen.toy deleted file mode 100644 index fa639c0..0000000 --- a/mlir/example/Ch8/struct-codegen.toy +++ /dev/null @@ -1,19 +0,0 @@ -struct Struct { - var a; - var b; -} - -# User defined generic function may operate on struct types as well. -def multiply_transpose(Struct value) { - # We can access the elements of a struct via the '.' operator. - return transpose(value.a) * transpose(value.b); -} - -def main() { - # We initialize struct values using a composite initializer. - Struct value = {[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]}; - - # We pass these arguments to functions like we do with variables. - var c = multiply_transpose(value); - print(c); -} diff --git a/mlir/example/Ch8/toyc.cpp b/mlir/example/Ch8/toyc.cpp deleted file mode 100644 index fea5679..0000000 --- a/mlir/example/Ch8/toyc.cpp +++ /dev/null @@ -1,329 +0,0 @@ -//===- toyc.cpp - The Toy Compiler ----------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements the entry point for the Toy compiler. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Func/Extensions/AllExtensions.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "toy/AST.h" -#include "toy/Dialect.h" -#include "toy/Lexer.h" -#include "toy/MLIRGen.h" -#include "toy/Parser.h" -#include "toy/Passes.h" - -#include "mlir/Dialect/Affine/Passes.h" -#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" -#include "mlir/ExecutionEngine/ExecutionEngine.h" -#include "mlir/ExecutionEngine/OptUtils.h" -#include "mlir/IR/AsmState.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Verifier.h" -#include "mlir/InitAllDialects.h" -#include "mlir/Parser/Parser.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Export.h" -#include "mlir/Transforms/Passes.h" - -#include "llvm/ADT/StringRef.h" -#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" -#include "llvm/IR/Module.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/ErrorOr.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/SourceMgr.h" -#include "llvm/Support/TargetSelect.h" -#include "llvm/Support/raw_ostream.h" -#include -#include -#include -#include -#include - -using namespace toy; -namespace cl = llvm::cl; - -static cl::opt inputFilename(cl::Positional, - cl::desc(""), - cl::init("-"), - cl::value_desc("filename")); - -namespace { -enum InputType { Toy, MLIR }; -} // namespace -static cl::opt inputType( - "x", cl::init(Toy), cl::desc("Decided the kind of output desired"), - cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")), - cl::values(clEnumValN(MLIR, "mlir", - "load the input file as an MLIR file"))); - -namespace { -enum Action { - None, - DumpAST, - DumpMLIR, - DumpMLIRAffine, - DumpMLIRLLVM, - DumpLLVMIR, - RunJIT -}; -} // namespace -static cl::opt emitAction( - "emit", cl::desc("Select the kind of output desired"), - cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), - cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")), - cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine", - "output the MLIR dump after affine lowering")), - cl::values(clEnumValN(DumpMLIRLLVM, "mlir-llvm", - "output the MLIR dump after llvm lowering")), - cl::values(clEnumValN(DumpLLVMIR, "llvm", "output the LLVM IR dump")), - cl::values( - clEnumValN(RunJIT, "jit", - "JIT the code and run it by invoking the main function"))); - -static cl::opt enableOpt("opt", cl::desc("Enable optimizations")); - -/// Returns a Toy AST resulting from parsing the file or a nullptr on error. -std::unique_ptr parseInputFile(llvm::StringRef filename) { - llvm::ErrorOr> fileOrErr = - llvm::MemoryBuffer::getFileOrSTDIN(filename); - if (std::error_code ec = fileOrErr.getError()) { - llvm::errs() << "Could not open input file: " << ec.message() << "\n"; - return nullptr; - } - auto buffer = fileOrErr.get()->getBuffer(); - LexerBuffer lexer(buffer.begin(), buffer.end(), std::string(filename)); - Parser parser(lexer); - return parser.parseModule(); -} - -int loadMLIR(mlir::MLIRContext &context, - mlir::OwningOpRef &module) { - // Handle '.toy' input to the compiler. - if (inputType != InputType::MLIR && - !llvm::StringRef(inputFilename).ends_with(".mlir")) { - auto moduleAST = parseInputFile(inputFilename); - if (!moduleAST) - return 6; - module = mlirGen(context, *moduleAST); - return !module ? 1 : 0; - } - - // Otherwise, the input is '.mlir'. - llvm::ErrorOr> fileOrErr = - llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); - if (std::error_code ec = fileOrErr.getError()) { - llvm::errs() << "Could not open input file: " << ec.message() << "\n"; - return -1; - } - - // Parse the input mlir. - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); - module = mlir::parseSourceFile(sourceMgr, &context); - if (!module) { - llvm::errs() << "Error can't load file " << inputFilename << "\n"; - return 3; - } - return 0; -} - -int loadAndProcessMLIR(mlir::MLIRContext &context, - mlir::OwningOpRef &module) { - if (int error = loadMLIR(context, module)) - return error; - - mlir::PassManager pm(module.get()->getName()); - // Apply any generic pass manager command line options and run the pipeline. - if (mlir::failed(mlir::applyPassManagerCLOptions(pm))) - return 4; - - // Check to see what granularity of MLIR we are compiling to. - bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine; - bool isLoweringToLLVM = emitAction >= Action::DumpMLIRLLVM; - - if (enableOpt || isLoweringToAffine) { - // Inline all functions into main and then delete them. - pm.addPass(mlir::createInlinerPass()); - - // Now that there is only one function, we can infer the shapes of each of - // the operations. - mlir::OpPassManager &optPM = pm.nest(); - optPM.addPass(mlir::createCanonicalizerPass()); - optPM.addPass(mlir::toy::createShapeInferencePass()); - optPM.addPass(mlir::createCanonicalizerPass()); - optPM.addPass(mlir::createCSEPass()); - } - - if (isLoweringToAffine) { - // Partially lower the toy dialect. - pm.addPass(mlir::toy::createLowerToAffinePass()); - - // Add a few cleanups post lowering. - mlir::OpPassManager &optPM = pm.nest(); - optPM.addPass(mlir::createCanonicalizerPass()); - optPM.addPass(mlir::createCSEPass()); - - // Add optimizations if enabled. - if (enableOpt) { - optPM.addPass(mlir::affine::createLoopFusionPass()); - optPM.addPass(mlir::affine::createAffineScalarReplacementPass()); - } - } - - if (isLoweringToLLVM) { - // Finish lowering the toy IR to the LLVM dialect. - pm.addPass(mlir::toy::createLowerToLLVMPass()); - // This is necessary to have line tables emitted and basic - // debugger working. In the future we will add proper debug information - // emission directly from our frontend. - pm.addPass(mlir::LLVM::createDIScopeForLLVMFuncOpPass()); - } - - if (mlir::failed(pm.run(*module))) - return 4; - return 0; -} - -int dumpAST() { - if (inputType == InputType::MLIR) { - llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n"; - return 5; - } - - auto moduleAST = parseInputFile(inputFilename); - if (!moduleAST) - return 1; - - dump(*moduleAST); - return 0; -} - -int dumpLLVMIR(mlir::ModuleOp module) { - // Register the translation to LLVM IR with the MLIR context. - mlir::registerBuiltinDialectTranslation(*module->getContext()); - mlir::registerLLVMDialectTranslation(*module->getContext()); - - // Convert the module to LLVM IR in a new LLVM IR context. - llvm::LLVMContext llvmContext; - auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext); - if (!llvmModule) { - llvm::errs() << "Failed to emit LLVM IR\n"; - return -1; - } - - // Initialize LLVM targets. - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); - - // Create target machine and configure the LLVM Module - auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); - if (!tmBuilderOrError) { - llvm::errs() << "Could not create JITTargetMachineBuilder\n"; - return -1; - } - - auto tmOrError = tmBuilderOrError->createTargetMachine(); - if (!tmOrError) { - llvm::errs() << "Could not create TargetMachine\n"; - return -1; - } - mlir::ExecutionEngine::setupTargetTripleAndDataLayout(llvmModule.get(), - tmOrError.get().get()); - - /// Optionally run an optimization pipeline over the llvm module. - auto optPipeline = mlir::makeOptimizingTransformer( - /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, - /*targetMachine=*/nullptr); - if (auto err = optPipeline(llvmModule.get())) { - llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; - return -1; - } - llvm::errs() << *llvmModule << "\n"; - return 0; -} - -int runJit(mlir::ModuleOp module) { - // Initialize LLVM targets. - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); - - // Register the translation from MLIR to LLVM IR, which must happen before we - // can JIT-compile. - mlir::registerBuiltinDialectTranslation(*module->getContext()); - mlir::registerLLVMDialectTranslation(*module->getContext()); - - // An optimization pipeline to use within the execution engine. - auto optPipeline = mlir::makeOptimizingTransformer( - /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, - /*targetMachine=*/nullptr); - - // Create an MLIR execution engine. The execution engine eagerly JIT-compiles - // the module. - mlir::ExecutionEngineOptions engineOptions; - engineOptions.transformer = optPipeline; - auto maybeEngine = mlir::ExecutionEngine::create(module, engineOptions); - assert(maybeEngine && "failed to construct an execution engine"); - auto &engine = maybeEngine.get(); - - // Invoke the JIT-compiled function. - auto invocationResult = engine->invokePacked("main"); - if (invocationResult) { - llvm::errs() << "JIT invocation failed\n"; - return -1; - } - - return 0; -} - -int main(int argc, char **argv) { - // Register any command line options. - mlir::registerAsmPrinterCLOptions(); - mlir::registerMLIRContextCLOptions(); - mlir::registerPassManagerCLOptions(); - - cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); - - if (emitAction == Action::DumpAST) - return dumpAST(); - - // If we aren't dumping the AST, then we are compiling with/to MLIR. - mlir::DialectRegistry registry; - mlir::func::registerAllExtensions(registry); - - mlir::MLIRContext context(registry); - // Load our Dialect in this MLIR Context. - context.getOrLoadDialect(); - - mlir::OwningOpRef module; - if (int error = loadAndProcessMLIR(context, module)) - return error; - - // If we aren't exporting to non-mlir, then we are done. - bool isOutputingMLIR = emitAction <= Action::DumpMLIRLLVM; - if (isOutputingMLIR) { - module->dump(); - return 0; - } - - // Check to see if we are compiling to LLVM IR. - if (emitAction == Action::DumpLLVMIR) - return dumpLLVMIR(*module); - - // Otherwise, we must be running the jit. - if (emitAction == Action::RunJIT) - return runJit(*module); - - llvm::errs() << "No action specified (parsing only?), use -emit=\n"; - return -1; -} diff --git a/mlir/example/README.md b/mlir/example/README.md index 1dacef1..dc6e0cb 100644 --- a/mlir/example/README.md +++ b/mlir/example/README.md @@ -476,35 +476,6 @@ $ ./build/Ch7/mlir-example-ch7 Ch7/struct-codegen.toy -emit=jit # 9.000000 36.000000 ``` -- Ch8 - -```bash -$ ./build/Ch8/mlir-example-ch8 Ch8/matmul.toy.mlir -emit=mlir -# module { -# toy.func private @matmul_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> { -# %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64> -# %1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64> -# %2 = toy.matmul(%0 : tensor<*xf64>, %1 : tensor<*xf64>) to tensor<*xf64> -# toy.return %2 : tensor<*xf64> -# } -# toy.func @main() { -# %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64> -# %1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64> -# %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64> -# %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<3x2xf64> -# %4 = toy.generic_call @matmul_transpose(%1, %3) : (tensor<2x3xf64>, tensor<3x2xf64>) -> tensor<*xf64> -# toy.print %4 : tensor<*xf64> -# toy.return -# } -# } -``` - -```bash -$ ./build/Ch8/mlir-example-ch8 Ch8/matmul.toy -emit=jit -# 14.000000 32.000000 -# 32.000000 77.000000 -``` - ### Transform Dialect Please flow the [mlir-transform-tutorial](https://mlir.llvm.org/docs/Tutorials/transform/). If you have some questions about the way to run these examples, please check the top lines of each mlir files. diff --git a/mlir/example/scripts/build_deps.sh b/mlir/example/scripts/build_deps.sh index 6468e3d..c118cac 100644 --- a/mlir/example/scripts/build_deps.sh +++ b/mlir/example/scripts/build_deps.sh @@ -19,10 +19,16 @@ set -e # exit 1 #fi +if [[ -f "/usr/bin/git" ]]; then + WORKSPACEROOT=$(git rev-parse --show-toplevel)/mlir/example || WORKSPACEROOT=`pwd` +fi + +cd ${WORKSPACEROOT} + # LLVM source -LLVM_SRC_DIR="${1:-third_party/llvm-project}" +LLVM_SRC_DIR="${1:-${WORKSPACEROOT}/third_party/llvm-project}" build_dir="${LLVM_SRC_DIR}/build" -install_dir="${2:-third_party}"/llvm +install_dir="${2:-${WORKSPACEROOT}/third_party/llvm}" if ! [ -f "$LLVM_SRC_DIR/llvm/CMakeLists.txt" ]; then echo "Expected the path to LLVM to be set correctly (got '$LLVM_SRC_DIR'): can't find CMakeLists.txt" @@ -42,13 +48,15 @@ mkdir -p ${install_dir} echo "Beginning build (commands will echo)" set -x +cd $LLVM_SRC_DIR + cmake -GNinja \ - "-H$LLVM_SRC_DIR/llvm" \ - "-B$build_dir" \ + "-H llvm" \ + "-B $build_dir" \ -DCMAKE_BUILD_TYPE=Debug \ -DLLVM_ENABLE_PROJECTS=mlir \ -DLLVM_TARGETS_TO_BUILD="X86;NVPTX;AMDGPU" \ - -DLLVM_ENABLE_LLD=ON \ + -DLLVM_ENABLE_LLD=OFF \ -DLLVM_ENABLE_BACKTRACES=OFF \ -DLLVM_INCLUDE_UTILS=ON \ -DCMAKE_INSTALL_PREFIX=${install_dir} \ diff --git a/mlir/example/scripts/sync_deps.sh b/mlir/example/scripts/sync_deps.sh index a1edad4..079c0c4 100644 --- a/mlir/example/scripts/sync_deps.sh +++ b/mlir/example/scripts/sync_deps.sh @@ -2,4 +2,4 @@ mkdir -p third_party -git clone -b release/19.x --depth 1 https://github.com/llvm/llvm-project.git third_party/llvm-project +git clone -b llvmorg-22.1.0 --depth 1 https://github.com/llvm/llvm-project.git third_party/llvm-project diff --git a/mlir/example/transform_Ch2/include/MyExtension.td b/mlir/example/transform_Ch2/include/MyExtension.td index 15cd1e6..1abd952 100644 --- a/mlir/example/transform_Ch2/include/MyExtension.td +++ b/mlir/example/transform_Ch2/include/MyExtension.td @@ -19,29 +19,29 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" -// Define the new operation. By convention, prefix its name with the name of the dialect +// Define the new operation. By convention, prefix its name with the name of the dialect // extension, "my.". The full operation name will be further prefixed with "transform.". def ChangeCallTargetOp : Op, DeclareOpInterfaceMethods]> { - // Provide a brief and a full description. It is recommended that the latter describes + // Provide a brief and a full description. It is recommended that the latter describes // the effects on the operands and how the operation processes various failure modes. let summary = "Changes the callee of a call operation to the specified one"; let description = [{ - For each `func.call` payload operation associated with the handle, changes its + For each `func.call` payload operation associated with the handle, changes its callee to be the symbol whose name is provided as an attribute to this operation. - Generates a silenceable failure if the operand is associated with payload operations + Generates a silenceable failure if the operand is associated with payload operations that are not `func.call`. Only reads the operand. }]; - // The arguments include the handle to the payload operations and the attribute that - // specifies the new callee. The handle must implement TransformHandleTypeInterface. - // We use a string attribute as the symbol may not exist in the transform IR so the - // verification may fail. + // The arguments include the handle to the payload operations and the attribute that + // specifies the new callee. The handle must implement TransformHandleTypeInterface. + // We use a string attribute as the symbol may not exist in the transform IR so the + // verification may fail. let arguments = (ins TransformHandleTypeInterface:$call, StrAttr:$new_target); diff --git a/mlir/example/transform_Ch2/lib/MyExtension.cpp b/mlir/example/transform_Ch2/lib/MyExtension.cpp index 68d538a..b4b27e9 100644 --- a/mlir/example/transform_Ch2/lib/MyExtension.cpp +++ b/mlir/example/transform_Ch2/lib/MyExtension.cpp @@ -29,6 +29,9 @@ class MyExtension : public ::mlir::transform::TransformDialectExtension { public: + // The TypeID of this extension. + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension) + // The extension must derive the base constructor. using Base::Base; diff --git a/mlir/example/transform_Ch3/include/MyExtension.td b/mlir/example/transform_Ch3/include/MyExtension.td index 7944f91..49874a7 100644 --- a/mlir/example/transform_Ch3/include/MyExtension.td +++ b/mlir/example/transform_Ch3/include/MyExtension.td @@ -21,34 +21,34 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" -// Define the new operation. By convention, prefix its name with the name of the dialect +// Define the new operation. By convention, prefix its name with the name of the dialect // extension, "my.". The full operation name will be further prefixed with "transform.". def ChangeCallTargetOp : Op]> { - // Provide a brief and a full description. It is recommended that the latter describes + // Provide a brief and a full description. It is recommended that the latter describes // the effects on the operands and how the operation processes various failure modes. let summary = "Changes the callee of a call operation to the specified one"; let description = [{ - For each `func.call` payload operation associated with the handle, changes its + For each `func.call` payload operation associated with the handle, changes its callee to be the symbol whose name is provided as an attribute to this operation. - Generates a silenceable failure if the operand is associated with payload operations + Generates a silenceable failure if the operand is associated with payload operations that are not `func.call`. Only reads the operand. }]; - // The arguments include the handle to the payload operations and the attribute that - // specifies the new callee. The handle must implement TransformHandleTypeInterface. - // We use a string attribute as the symbol may not exist in the transform IR so the - // verification may fail. + // The arguments include the handle to the payload operations and the attribute that + // specifies the new callee. The handle must implement TransformHandleTypeInterface. + // We use a string attribute as the symbol may not exist in the transform IR so the + // verification may fail. let arguments = (ins - // Specify the type constraint on the input accepting only `func.call` payload - // operations. - Transform_ConcreteOpType<"func.call">:$call, + // Allow the handle to be to concrete func.call ops as well as any op implementing + // the CallOpInterface. + AnyTypeOf<[Transform_ConcreteOpType<"func.call">, CallOpInterfaceHandle]>:$call, StrAttr:$new_target); // The results are empty as the transformation does not produce any new payload. @@ -80,7 +80,7 @@ def CallToOp : Op]> { - // The usual components of a type such as description, mnemonic and assembly format + // The usual components of a type such as description, mnemonic and assembly format // should be provided. let summary = "handle to payload operations implementing CallOpInterface"; let mnemonic = "my.call_op_interface"; diff --git a/mlir/example/transform_Ch3/lib/MyExtension.cpp b/mlir/example/transform_Ch3/lib/MyExtension.cpp index f7a9942..4b2123f 100644 --- a/mlir/example/transform_Ch3/lib/MyExtension.cpp +++ b/mlir/example/transform_Ch3/lib/MyExtension.cpp @@ -35,6 +35,9 @@ class MyExtension : public ::mlir::transform::TransformDialectExtension { public: + // The TypeID of this extension. + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension) + // The extension must derive the base constructor. using Base::Base; diff --git a/mlir/example/transform_Ch4/lib/MyExtension.cpp b/mlir/example/transform_Ch4/lib/MyExtension.cpp index 38c8ca1..2159483 100644 --- a/mlir/example/transform_Ch4/lib/MyExtension.cpp +++ b/mlir/example/transform_Ch4/lib/MyExtension.cpp @@ -13,11 +13,9 @@ #include "MyExtension.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" -#define DEBUG_TYPE_MATCHER "transform-matcher" -#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ") -#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x) +#define DEBUG_TYPE "transform-matcher" #define GET_OP_CLASSES #include "MyExtension.cpp.inc" @@ -31,6 +29,9 @@ class MyExtension : public ::mlir::transform::TransformDialectExtension { public: + // The TypeID of this extension. + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyExtension) + // The extension must derive the base constructor. using Base::Base; @@ -121,9 +122,8 @@ mlir::transform::HasOperandSatisfyingOp::apply( // Report failure-to-match for debugging purposes and stop matching this // operand. assert(diag.isSilenceableFailure()); - DEBUG_MATCHER(DBGS_MATCHER() - << "failed to match operand #" << operand.getOperandNumber() - << ": " << diag.getMessage()); + LDBG() << "failed to match operand #" << operand.getOperandNumber() + << ": " << diag.getMessage(); (void)diag.silence(); matchSucceeded = false; break;