[Mlir-commits] [mlir] 95b4e88 - [mlir:PDLL] Add support for PDL MLIR code generation
River Riddle
llvmlistbot at llvm.org
Sat Feb 26 11:26:38 PST 2022
Author: River Riddle
Date: 2022-02-26T11:08:51-08:00
New Revision: 95b4e88b1db348fbb074c945bd85c777cf807cc0
URL: https://github.com/llvm/llvm-project/commit/95b4e88b1db348fbb074c945bd85c777cf807cc0
DIFF: https://github.com/llvm/llvm-project/commit/95b4e88b1db348fbb074c945bd85c777cf807cc0.diff
LOG: [mlir:PDLL] Add support for PDL MLIR code generation
This commit starts to plumb PDLL down into MLIR and adds an initial
PDL generator. After this commit, we conceptually support
end-to-end execution of PDLL. Follow-ups will add C++ generation to
match the current DRR setup, and begin adding various end-to-end
tests for PDLL execution.
Differential Revision: https://reviews.llvm.org/D119779
Added:
mlir/include/mlir/Tools/PDLL/CodeGen/MLIRGen.h
mlir/lib/Tools/PDLL/CodeGen/CMakeLists.txt
mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
mlir/test/mlir-pdll/CodeGen/MLIR/decl.pdll
mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
mlir/test/mlir-pdll/CodeGen/MLIR/stmt.pdll
mlir/test/mlir-pdll/lit.local.cfg
Modified:
mlir/include/mlir/Tools/PDLL/AST/Nodes.h
mlir/lib/Tools/PDLL/CMakeLists.txt
mlir/tools/mlir-pdll/CMakeLists.txt
mlir/tools/mlir-pdll/mlir-pdll.cpp
Removed:
mlir/test/mlir-pdll/Parser/lit.local.cfg
################################################################################
diff --git a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
index 6824354a16edd..03c0215129fa8 100644
--- a/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
+++ b/mlir/include/mlir/Tools/PDLL/AST/Nodes.h
@@ -1239,8 +1239,8 @@ inline bool CoreConstraintDecl::classof(const Node *node) {
}
inline bool Expr::classof(const Node *node) {
- return isa<AttributeExpr, DeclRefExpr, MemberAccessExpr, OperationExpr,
- TupleExpr, TypeExpr>(node);
+ // NOTE: this list should name every concrete Expr subclass; without
+ // CallExpr here, isa<Expr>/dyn_cast<Expr> would reject call expressions.
+ return isa<AttributeExpr, CallExpr, DeclRefExpr, MemberAccessExpr,
+ OperationExpr, TupleExpr, TypeExpr>(node);
}
inline bool OpRewriteStmt::classof(const Node *node) {
diff --git a/mlir/include/mlir/Tools/PDLL/CodeGen/MLIRGen.h b/mlir/include/mlir/Tools/PDLL/CodeGen/MLIRGen.h
new file mode 100644
index 0000000000000..c691c39318c12
--- /dev/null
+++ b/mlir/include/mlir/Tools/PDLL/CodeGen/MLIRGen.h
@@ -0,0 +1,41 @@
+//===- MLIRGen.h - MLIR PDLL Code Generation --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_PDLL_CODEGEN_MLIRGEN_H_
+#define MLIR_TOOLS_PDLL_CODEGEN_MLIRGEN_H_
+
+#include <memory>
+
+#include "mlir/Support/LogicalResult.h"
+
+namespace llvm {
+class SourceMgr;
+} // namespace llvm
+
+namespace mlir {
+class MLIRContext;
+class ModuleOp;
+template <typename OpT>
+class OwningOpRef;
+
+namespace pdll {
+namespace ast {
+class Context;
+class Module;
+} // namespace ast
+
+/// Given a PDLL module, generate an MLIR PDL pattern module within the given
+/// MLIR context.
+///
+/// The PDL dialect is loaded into `mlirContext` as part of generation.
+/// `sourceMgr` is expected to contain the buffers that `module`'s source
+/// locations point into; it is used to build file/line/column locations for
+/// the generated operations. `context` is the PDLL AST context
+/// (NOTE(review): currently not used by the implementation).
+/// Returns a null module if the generated IR fails to verify.
+OwningOpRef<ModuleOp> codegenPDLLToMLIR(MLIRContext *mlirContext,
+ const ast::Context &context,
+ const llvm::SourceMgr &sourceMgr,
+ const ast::Module &module);
+} // namespace pdll
+} // namespace mlir
+
+#endif // MLIR_TOOLS_PDLL_CODEGEN_MLIRGEN_H_
diff --git a/mlir/lib/Tools/PDLL/CMakeLists.txt b/mlir/lib/Tools/PDLL/CMakeLists.txt
index efdfcd37bf744..ac83f5e5fae77 100644
--- a/mlir/lib/Tools/PDLL/CMakeLists.txt
+++ b/mlir/lib/Tools/PDLL/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(AST)
+add_subdirectory(CodeGen)
add_subdirectory(Parser)
diff --git a/mlir/lib/Tools/PDLL/CodeGen/CMakeLists.txt b/mlir/lib/Tools/PDLL/CodeGen/CMakeLists.txt
new file mode 100644
index 0000000000000..f1e59126623fc
--- /dev/null
+++ b/mlir/lib/Tools/PDLL/CodeGen/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Library that lowers the PDLL AST into PDL dialect MLIR.
+add_mlir_library(MLIRPDLLCodeGen
+ MLIRGen.cpp
+
+ LINK_LIBS PUBLIC
+ # MLIRParser provides parseAttribute/parseType, used to turn PDLL attribute
+ # and type literals back into MLIR entities.
+ MLIRParser
+ MLIRPDLLAST
+ MLIRPDL
+ MLIRSupport
+ )
diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
new file mode 100644
index 0000000000000..17afb3bcc47ed
--- /dev/null
+++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
@@ -0,0 +1,586 @@
+//===- MLIRGen.cpp --------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDL/IR/PDLOps.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Parser.h"
+#include "mlir/Tools/PDLL/AST/Context.h"
+#include "mlir/Tools/PDLL/AST/Nodes.h"
+#include "mlir/Tools/PDLL/AST/Types.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::pdll;
+
+//===----------------------------------------------------------------------===//
+// CodeGen
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Lowers a PDLL AST module into PDL dialect operations nested within a
+/// builtin module. PDLL constraint and rewrite declarations are not emitted on
+/// their own; each use either inlines the declaration's body or, for native
+/// declarations without a body, emits a single native call operation.
+class CodeGen {
+public:
+ /// NOTE(review): `context` (the PDLL AST context) is accepted but is not
+ /// currently used by the generator.
+ CodeGen(MLIRContext *mlirContext, const ast::Context &context,
+ const llvm::SourceMgr &sourceMgr)
+ : builder(mlirContext), sourceMgr(sourceMgr) {
+ // Make sure that the PDL dialect is loaded.
+ mlirContext->loadDialect<pdl::PDLDialect>();
+ }
+
+ /// Generate a new MLIR module containing the PDL lowering of `module`.
+ OwningOpRef<ModuleOp> generate(const ast::Module &module);
+
+private:
+ /// Generate an MLIR location from the given source location.
+ Location genLoc(llvm::SMLoc loc);
+ Location genLoc(llvm::SMRange loc) { return genLoc(loc.Start); }
+
+ /// Generate an MLIR type from the given source type.
+ Type genType(ast::Type type);
+
+ /// Generate MLIR for the given AST node.
+ void gen(const ast::Node *node);
+
+ //===--------------------------------------------------------------------===//
+ // Statements
+ //===--------------------------------------------------------------------===//
+
+ void genImpl(const ast::CompoundStmt *stmt);
+ void genImpl(const ast::EraseStmt *stmt);
+ void genImpl(const ast::LetStmt *stmt);
+ void genImpl(const ast::ReplaceStmt *stmt);
+ void genImpl(const ast::RewriteStmt *stmt);
+ void genImpl(const ast::ReturnStmt *stmt);
+
+ //===--------------------------------------------------------------------===//
+ // Decls
+ //===--------------------------------------------------------------------===//
+
+ void genImpl(const ast::UserConstraintDecl *decl);
+ void genImpl(const ast::UserRewriteDecl *decl);
+ void genImpl(const ast::PatternDecl *decl);
+
+ /// Generate the set of MLIR values defined for the given variable decl, and
+ /// apply any attached constraints.
+ SmallVector<Value> genVar(const ast::VariableDecl *varDecl);
+
+ /// Generate the value for a variable that does not have an initializer
+ /// expression, i.e. create the PDL value based on the type/constraints of the
+ /// variable.
+ Value genNonInitializerVar(const ast::VariableDecl *varDecl, Location loc);
+
+ /// Apply the constraints of the given variable to `values`, which correspond
+ /// to the MLIR values of the variable.
+ void applyVarConstraints(const ast::VariableDecl *varDecl, ValueRange values);
+
+ //===--------------------------------------------------------------------===//
+ // Expressions
+ //===--------------------------------------------------------------------===//
+
+ /// Generate the single MLIR value produced by `expr`; asserts if the
+ /// expression produces a different number of values.
+ Value genSingleExpr(const ast::Expr *expr);
+ /// Generate all MLIR values produced by `expr`. Call, decl reference, and
+ /// tuple expressions may produce several values; others produce exactly one.
+ SmallVector<Value> genExpr(const ast::Expr *expr);
+ Value genExprImpl(const ast::AttributeExpr *expr);
+ SmallVector<Value> genExprImpl(const ast::CallExpr *expr);
+ SmallVector<Value> genExprImpl(const ast::DeclRefExpr *expr);
+ Value genExprImpl(const ast::MemberAccessExpr *expr);
+ Value genExprImpl(const ast::OperationExpr *expr);
+ SmallVector<Value> genExprImpl(const ast::TupleExpr *expr);
+ Value genExprImpl(const ast::TypeExpr *expr);
+
+ /// Generate a call to a user constraint, applying the constraints declared
+ /// on its inputs (before the call) and on its results (after the call).
+ SmallVector<Value> genConstraintCall(const ast::UserConstraintDecl *decl,
+ Location loc, ValueRange inputs);
+ /// Generate a call to a user rewrite.
+ SmallVector<Value> genRewriteCall(const ast::UserRewriteDecl *decl,
+ Location loc, ValueRange inputs);
+ /// Shared call lowering: a native decl (no body) becomes a single `PDLOpT`
+ /// operation, while a PDLL decl has its body inlined at the call site.
+ template <typename PDLOpT, typename T>
+ SmallVector<Value> genConstraintOrRewriteCall(const T *decl, Location loc,
+ ValueRange inputs);
+
+ //===--------------------------------------------------------------------===//
+ // Fields
+ //===--------------------------------------------------------------------===//
+
+ /// The MLIR builder used for building the resultant IR.
+ OpBuilder builder;
+
+ /// A map from variable declarations to the MLIR equivalent.
+ /// Scoped: compound statements and inlined calls open their own scope, and
+ /// entries added within a scope are dropped when it ends.
+ using VariableMapTy =
+ llvm::ScopedHashTable<const ast::VariableDecl *, SmallVector<Value>>;
+ VariableMapTy variables;
+
+ /// The source manager of the PDLL ast.
+ const llvm::SourceMgr &sourceMgr;
+};
+} // namespace
+
+OwningOpRef<ModuleOp> CodeGen::generate(const ast::Module &module) {
+  // Create the top-level builtin module that will hold every generated
+  // pattern, and start inserting at the beginning of its body.
+  Location moduleLoc = genLoc(module.getLoc());
+  OwningOpRef<ModuleOp> result = builder.create<ModuleOp>(moduleLoc);
+  builder.setInsertionPointToStart(result->getBody());
+
+  // Lower every top-level declaration, in source order.
+  for (const ast::Decl *childDecl : module.getChildren())
+    gen(childDecl);
+  return result;
+}
+
+Location CodeGen::genLoc(llvm::SMLoc loc) {
+  unsigned bufferID = sourceMgr.FindBufferContainingLoc(loc);
+  const char *locPtr = loc.getPointer();
+
+  // TODO: Fix performance issues in SourceMgr::getLineAndColumn so that we can
+  // use it here.
+  const auto &info = sourceMgr.getBufferInfo(bufferID);
+  unsigned line = info.getLineNumber(locPtr);
+  const char *lineStart = info.getPointerForLineNumber(line);
+  // Columns are 1-based.
+  unsigned col = (locPtr - lineStart) + 1;
+
+  StringRef fileName =
+      sourceMgr.getMemoryBuffer(bufferID)->getBufferIdentifier();
+  return FileLineColLoc::get(builder.getContext(), fileName, line, col);
+}
+
+Type CodeGen::genType(ast::Type type) {
+  // Each PDLL AST type maps onto a PDL dialect type; range types are converted
+  // by recursively converting their element type.
+  return TypeSwitch<ast::Type, Type>(type)
+      .Case([&](ast::RangeType rangeTy) -> Type {
+        Type elementTy = genType(rangeTy.getElementType());
+        return pdl::RangeType::get(elementTy);
+      })
+      .Case([&](ast::ValueType) -> Type {
+        return builder.getType<pdl::ValueType>();
+      })
+      .Case([&](ast::TypeType) -> Type {
+        return builder.getType<pdl::TypeType>();
+      })
+      .Case([&](ast::OperationType) -> Type {
+        return builder.getType<pdl::OperationType>();
+      })
+      .Case([&](ast::AttributeType) -> Type {
+        return builder.getType<pdl::AttributeType>();
+      });
+}
+
+void CodeGen::gen(const ast::Node *node) {
+  // Forward statements and declarations to the matching genImpl overload.
+  auto dispatchToImpl = [&](auto concreteNode) {
+    this->genImpl(concreteNode);
+  };
+  TypeSwitch<const ast::Node *>(node)
+      .Case<const ast::CompoundStmt, const ast::EraseStmt, const ast::LetStmt,
+            const ast::ReplaceStmt, const ast::RewriteStmt,
+            const ast::ReturnStmt>(dispatchToImpl)
+      .Case<const ast::UserConstraintDecl, const ast::UserRewriteDecl,
+            const ast::PatternDecl>(dispatchToImpl)
+      // Expressions in statement position: the produced values are discarded.
+      .Case([&](const ast::Expr *expr) { genExpr(expr); });
+}
+
+//===----------------------------------------------------------------------===//
+// CodeGen: Statements
+//===----------------------------------------------------------------------===//
+
+void CodeGen::genImpl(const ast::CompoundStmt *stmt) {
+  // Variables declared by the child statements are only visible until the end
+  // of this compound statement.
+  VariableMapTy::ScopeTy compoundScope(variables);
+  llvm::for_each(stmt->getChildren(),
+                 [&](const ast::Stmt *child) { gen(child); });
+}
+
+/// Ensure that operations built next land inside a `pdl.rewrite`: when the
+/// builder currently inserts directly into a `pdl.pattern`, create a rewrite
+/// rooted at `rootExpr` and move the builder into the start of its new body
+/// block. This is needed for PDLL operation rewrite statements written
+/// directly within a Pattern.
+static void checkAndNestUnderRewriteOp(OpBuilder &builder, Value rootExpr,
+                                       Location loc) {
+  // Anywhere other than directly inside a pattern (presumably already within
+  // a rewrite region), leave the builder untouched.
+  Operation *enclosingOp = builder.getInsertionBlock()->getParentOp();
+  if (!isa<pdl::PatternOp>(enclosingOp))
+    return;
+
+  auto rewriteOp = builder.create<pdl::RewriteOp>(
+      loc, rootExpr, /*name=*/StringAttr(),
+      /*externalArgs=*/ValueRange(), /*externalConstParams=*/ArrayAttr());
+  Region &bodyRegion = rewriteOp.body();
+  builder.createBlock(&bodyRegion);
+}
+
+/// Generate a `pdl.erase` of the statement's root operation, nesting it under
+/// a new `pdl.rewrite` when the statement appears directly in a pattern.
+void CodeGen::genImpl(const ast::EraseStmt *stmt) {
+  // A single guard taken up front restores the builder's insertion point on
+  // exit (the builder may be moved into a new rewrite body below). A second
+  // guard after generating the root would capture the same point and is
+  // redundant.
+  OpBuilder::InsertionGuard insertGuard(builder);
+  Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
+  Location loc = genLoc(stmt->getLoc());
+
+  // Make sure we are nested in a RewriteOp.
+  checkAndNestUnderRewriteOp(builder, rootExpr, loc);
+  builder.create<pdl::EraseOp>(loc, rootExpr);
+}
+
+void CodeGen::genImpl(const ast::LetStmt *stmt) {
+  // Generating the variable records its values in `variables`, where later
+  // references pick them up; the returned values are not needed here.
+  (void)genVar(stmt->getVarDecl());
+}
+
+/// Generate a `pdl.replace` of the statement's root operation, either with a
+/// single replacement operation or with a list of replacement values.
+void CodeGen::genImpl(const ast::ReplaceStmt *stmt) {
+  // A single guard taken up front restores the builder's insertion point on
+  // exit (the builder may be moved into a new rewrite body below). A second
+  // guard after generating the root would capture the same point and is
+  // redundant.
+  OpBuilder::InsertionGuard insertGuard(builder);
+  Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
+  Location loc = genLoc(stmt->getLoc());
+
+  // Make sure we are nested in a RewriteOp.
+  checkAndNestUnderRewriteOp(builder, rootExpr, loc);
+
+  SmallVector<Value> replValues;
+  for (ast::Expr *replExpr : stmt->getReplExprs())
+    replValues.push_back(genSingleExpr(replExpr));
+
+  // Check to see if the statement has a replacement operation, or a range of
+  // replacement values.
+  bool usesReplOperation =
+      replValues.size() == 1 &&
+      replValues.front().getType().isa<pdl::OperationType>();
+  builder.create<pdl::ReplaceOp>(
+      loc, rootExpr, usesReplOperation ? replValues[0] : Value(),
+      usesReplOperation ? ValueRange() : ValueRange(replValues));
+}
+
+/// Generate the body of an explicit rewrite statement, nesting it under a new
+/// `pdl.rewrite` when the statement appears directly in a pattern.
+void CodeGen::genImpl(const ast::RewriteStmt *stmt) {
+  // A single guard taken up front restores the builder's insertion point on
+  // exit (the builder may be moved into a new rewrite body below). A second
+  // guard after generating the root would capture the same point and is
+  // redundant.
+  OpBuilder::InsertionGuard insertGuard(builder);
+  Value rootExpr = genSingleExpr(stmt->getRootOpExpr());
+
+  // Make sure we are nested in a RewriteOp.
+  checkAndNestUnderRewriteOp(builder, rootExpr, genLoc(stmt->getLoc()));
+  gen(stmt->getRewriteBody());
+}
+
+void CodeGen::genImpl(const ast::ReturnStmt *stmt) {
+ // ReturnStmt generation is handled by the respective constraint or rewrite
+ // parent node.
+ // When a PDLL constraint/rewrite body is inlined, genConstraintOrRewriteCall
+ // generates the expression of the trailing ReturnStmt itself, so nothing is
+ // emitted here.
+}
+
+//===----------------------------------------------------------------------===//
+// CodeGen: Decls
+//===----------------------------------------------------------------------===//
+
+void CodeGen::genImpl(const ast::UserConstraintDecl *decl) {
+ // All PDLL constraints get inlined when called, and native constraint
+ // declarations don't require any MLIR to be generated; only their uses do
+ // (see genConstraintCall).
+}
+
+void CodeGen::genImpl(const ast::UserRewriteDecl *decl) {
+ // All PDLL rewrites get inlined when called, and native rewrite
+ // declarations don't require any MLIR to be generated; only their uses do
+ // (see genRewriteCall).
+}
+
+void CodeGen::genImpl(const ast::PatternDecl *decl) {
+  // Unnamed patterns produce a pdl.pattern without a symbol name.
+  Optional<StringRef> patternName;
+  if (const ast::Name *name = decl->getName())
+    patternName = name->getName();
+
+  // FIXME: Properly model HasBoundedRecursion in PDL so that we don't drop it
+  // here.
+  auto pattern = builder.create<pdl::PatternOp>(
+      genLoc(decl->getLoc()), decl->getBenefit(), patternName);
+
+  // Lower the pattern body inside the new op, then restore the builder.
+  OpBuilder::InsertionGuard bodyGuard(builder);
+  builder.setInsertionPointToStart(pattern.getBody());
+  gen(decl->getBody());
+}
+
+SmallVector<Value> CodeGen::genVar(const ast::VariableDecl *varDecl) {
+  // A variable that is already in scope reuses its previously generated
+  // values instead of emitting new IR.
+  auto existing = variables.begin(varDecl);
+  if (existing != variables.end())
+    return *existing;
+
+  // Use the initializer when present; otherwise derive a fresh PDL value from
+  // the variable's type and constraints.
+  SmallVector<Value> values;
+  const ast::Expr *initExpr = varDecl->getInitExpr();
+  if (!initExpr) {
+    Location varLoc = genLoc(varDecl->getLoc());
+    values.push_back(genNonInitializerVar(varDecl, varLoc));
+  } else {
+    values = genExpr(initExpr);
+  }
+
+  applyVarConstraints(varDecl, values);
+  variables.insert(varDecl, values);
+  return values;
+}
+
+/// Create the PDL value for a variable declared without an initializer, based
+/// on the variable's AST type. Any type constraint attached to the variable is
+/// forwarded as the type operand of the created op where the op accepts one.
+Value CodeGen::genNonInitializerVar(const ast::VariableDecl *varDecl,
+ Location loc) {
+ // Returns the value generated for the type expression of the first
+ // `AttrConstraintDecl`, `ValueConstraintDecl`, or `ValueRangeConstraintDecl`
+ // attached to `varDecl` that has one, or a null Value if there is none.
+ auto getTypeConstraint = [&]() -> Value {
+ for (const ast::ConstraintRef &constraint : varDecl->getConstraints()) {
+ Value typeValue =
+ TypeSwitch<const ast::Node *, Value>(constraint.constraint)
+ .Case<ast::AttrConstraintDecl, ast::ValueConstraintDecl,
+ ast::ValueRangeConstraintDecl>([&](auto *cst) -> Value {
+ if (auto *typeConstraintExpr = cst->getTypeExpr())
+ return genSingleExpr(typeConstraintExpr);
+ return Value();
+ })
+ .Default(Value());
+ if (typeValue)
+ return typeValue;
+ }
+ return Value();
+ };
+
+ // Generate a value based on the type of the variable.
+ ast::Type type = varDecl->getType();
+ Type mlirType = genType(type);
+ if (type.isa<ast::ValueType>())
+ return builder.create<pdl::OperandOp>(loc, mlirType, getTypeConstraint());
+ if (type.isa<ast::TypeType>())
+ return builder.create<pdl::TypeOp>(loc, mlirType, /*type=*/TypeAttr());
+ if (type.isa<ast::AttributeType>())
+ return builder.create<pdl::AttributeOp>(loc, getTypeConstraint());
+ // An operation variable matches any operands and result types and no
+ // specific attributes; the op name comes from the operation type.
+ if (ast::OperationType opType = type.dyn_cast<ast::OperationType>()) {
+ Value operands = builder.create<pdl::OperandsOp>(
+ loc, pdl::RangeType::get(builder.getType<pdl::ValueType>()),
+ /*type=*/Value());
+ Value results = builder.create<pdl::TypesOp>(
+ loc, pdl::RangeType::get(builder.getType<pdl::TypeType>()),
+ /*types=*/ArrayAttr());
+ return builder.create<pdl::OperationOp>(loc, opType.getName(), operands,
+ llvm::None, ValueRange(), results);
+ }
+
+ // Only ranges of values or of types can be declared without an initializer.
+ if (ast::RangeType rangeTy = type.dyn_cast<ast::RangeType>()) {
+ ast::Type eleTy = rangeTy.getElementType();
+ if (eleTy.isa<ast::ValueType>())
+ return builder.create<pdl::OperandsOp>(loc, mlirType,
+ getTypeConstraint());
+ if (eleTy.isa<ast::TypeType>())
+ return builder.create<pdl::TypesOp>(loc, mlirType, /*types=*/ArrayAttr());
+ }
+
+ llvm_unreachable("invalid non-initialized variable type");
+}
+
+void CodeGen::applyVarConstraints(const ast::VariableDecl *varDecl,
+                                  ValueRange values) {
+  // Only user-defined constraints from the variable's constraint list need a
+  // call; the remaining (core) constraints are skipped here.
+  for (const ast::ConstraintRef &ref : varDecl->getConstraints()) {
+    const auto *userCst = dyn_cast<ast::UserConstraintDecl>(ref.constraint);
+    if (!userCst)
+      continue;
+    genConstraintCall(userCst, genLoc(ref.referenceLoc), values);
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// CodeGen: Expressions
+//===----------------------------------------------------------------------===//
+
+Value CodeGen::genSingleExpr(const ast::Expr *expr) {
+  // Expression kinds whose implementation yields a list of values must yield
+  // exactly one value when used in a single-value context.
+  auto genFromList = [&](auto concreteExpr) {
+    SmallVector<Value> values = this->genExprImpl(concreteExpr);
+    assert(values.size() == 1 && "expected single expression result");
+    return values[0];
+  };
+  return TypeSwitch<const ast::Expr *, Value>(expr)
+      .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
+          genFromList)
+      .Case<const ast::AttributeExpr, const ast::MemberAccessExpr,
+            const ast::OperationExpr, const ast::TypeExpr>(
+          [&](auto concreteExpr) { return this->genExprImpl(concreteExpr); });
+}
+
+SmallVector<Value> CodeGen::genExpr(const ast::Expr *expr) {
+  // Every expression kind other than call, decl reference, and tuple produces
+  // exactly one value.
+  auto wrapSingleValue = [&](const ast::Expr *singleExpr) -> SmallVector<Value> {
+    return {genSingleExpr(singleExpr)};
+  };
+  return TypeSwitch<const ast::Expr *, SmallVector<Value>>(expr)
+      .Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
+          [&](auto concreteExpr) { return this->genExprImpl(concreteExpr); })
+      .Default(wrapSingleValue);
+}
+
+Value CodeGen::genExprImpl(const ast::AttributeExpr *expr) {
+  // The AST stores the attribute as MLIR assembly text; parse it back into a
+  // concrete attribute and emit it as a constant pdl.attribute.
+  MLIRContext *ctx = builder.getContext();
+  Attribute parsedAttr = parseAttribute(expr->getValue(), ctx);
+  assert(parsedAttr && "invalid MLIR attribute data");
+  Location loc = genLoc(expr->getLoc());
+  return builder.create<pdl::AttributeOp>(loc, parsedAttr);
+}
+
+SmallVector<Value> CodeGen::genExprImpl(const ast::CallExpr *expr) {
+  Location loc = genLoc(expr->getLoc());
+
+  // Arguments are generated, in order, before the call itself.
+  SmallVector<Value> argValues;
+  for (const ast::Expr *argExpr : expr->getArguments())
+    argValues.push_back(genSingleExpr(argExpr));
+
+  // Only direct references to a constraint or rewrite decl are callable.
+  const auto *calleeRef = dyn_cast<ast::DeclRefExpr>(expr->getCallableExpr());
+  assert(calleeRef && "unhandled CallExpr callable");
+  const ast::Decl *callee = calleeRef->getDecl();
+
+  if (const auto *rewriteDecl = dyn_cast<ast::UserRewriteDecl>(callee))
+    return genRewriteCall(rewriteDecl, loc, argValues);
+  if (const auto *constraintDecl = dyn_cast<ast::UserConstraintDecl>(callee))
+    return genConstraintCall(constraintDecl, loc, argValues);
+  llvm_unreachable("unhandled CallExpr callable");
+}
+
+SmallVector<Value> CodeGen::genExprImpl(const ast::DeclRefExpr *expr) {
+  // Only variable references produce values.
+  const auto *varDecl = dyn_cast<ast::VariableDecl>(expr->getDecl());
+  if (!varDecl)
+    llvm_unreachable("unknown decl reference expression");
+  return genVar(varDecl);
+}
+
+Value CodeGen::genExprImpl(const ast::MemberAccessExpr *expr) {
+  Location loc = genLoc(expr->getLoc());
+  StringRef memberName = expr->getMemberName();
+  const ast::Expr *parentExpr = expr->getParentExpr();
+  SmallVector<Value> parentValues = genExpr(parentExpr);
+  ast::Type parentType = parentExpr->getType();
+
+  // Operation member access: only the implicit "all results" form exists.
+  if (parentType.isa<ast::OperationType>()) {
+    if (!isa<ast::AllResultsMemberAccessExpr>(expr))
+      llvm_unreachable("unhandled operation member access expression");
+    Type resultType = genType(expr->getType());
+    // A single-value access reads result #0; otherwise take all results.
+    if (!resultType.isa<pdl::ValueType>())
+      return builder.create<pdl::ResultsOp>(loc, resultType, parentValues[0]);
+    return builder.create<pdl::ResultOp>(loc, resultType, parentValues[0],
+                                         builder.getI32IntegerAttr(0));
+  }
+
+  auto tupleType = parentType.dyn_cast<ast::TupleType>();
+  if (!tupleType)
+    llvm_unreachable("unhandled member access expression");
+
+  // Tuple members are referenced either positionally ("0", "1", ...) or by
+  // element name.
+  unsigned index = 0;
+  if (!llvm::isDigit(memberName[0])) {
+    auto elementNames = tupleType.getElementNames();
+    index = llvm::find(elementNames, memberName) - elementNames.begin();
+  } else {
+    memberName.getAsInteger(/*Radix=*/10, index);
+  }
+  assert(index < parentValues.size() && "invalid result index");
+  return parentValues[index];
+}
+
+Value CodeGen::genExprImpl(const ast::OperationExpr *expr) {
+  Location loc = genLoc(expr->getLoc());
+  Optional<StringRef> opName = expr->getName();
+
+  // Generate operands, then attribute values, then result types, in that
+  // order.
+  SmallVector<Value> operandValues;
+  for (const ast::Expr *operandExpr : expr->getOperands())
+    operandValues.push_back(genSingleExpr(operandExpr));
+
+  SmallVector<StringRef> attrNames;
+  SmallVector<Value> attrValues;
+  for (const ast::NamedAttributeDecl *namedAttr : expr->getAttributes()) {
+    StringRef attrName = namedAttr->getName().getName();
+    attrNames.push_back(attrName);
+    attrValues.push_back(genSingleExpr(namedAttr->getValue()));
+  }
+
+  SmallVector<Value> resultTypeValues;
+  for (const ast::Expr *resultTypeExpr : expr->getResultTypes())
+    resultTypeValues.push_back(genSingleExpr(resultTypeExpr));
+
+  return builder.create<pdl::OperationOp>(loc, opName, operandValues,
+                                          attrNames, attrValues,
+                                          resultTypeValues);
+}
+
+SmallVector<Value> CodeGen::genExprImpl(const ast::TupleExpr *expr) {
+  // Each tuple element contributes exactly one value, in element order.
+  auto elementExprs = expr->getElements();
+  SmallVector<Value> elementValues;
+  elementValues.reserve(elementExprs.size());
+  for (const ast::Expr *elementExpr : elementExprs)
+    elementValues.push_back(genSingleExpr(elementExpr));
+  return elementValues;
+}
+
+Value CodeGen::genExprImpl(const ast::TypeExpr *expr) {
+  // The AST stores the type as MLIR assembly text; parse it back into a
+  // concrete type and emit it as a constant pdl.type.
+  Type parsedType = parseType(expr->getValue(), builder.getContext());
+  assert(parsedType && "invalid MLIR type data");
+  Location loc = genLoc(expr->getLoc());
+  Type resultType = builder.getType<pdl::TypeType>();
+  return builder.create<pdl::TypeOp>(loc, resultType,
+                                     TypeAttr::get(parsedType));
+}
+
+SmallVector<Value>
+CodeGen::genConstraintCall(const ast::UserConstraintDecl *decl, Location loc,
+                           ValueRange inputs) {
+  // Apply the constraints declared on each variable in `vars` to the
+  // corresponding entry of `vals`, pairwise.
+  auto constrainPairwise = [&](const auto &vars, const auto &vals) {
+    for (auto varAndVal : llvm::zip(vars, vals))
+      this->applyVarConstraints(std::get<0>(varAndVal),
+                                std::get<1>(varAndVal));
+  };
+
+  // Argument constraints come first, then the call itself, then the
+  // constraints declared on the results.
+  constrainPairwise(decl->getInputs(), inputs);
+  SmallVector<Value> results =
+      genConstraintOrRewriteCall<pdl::ApplyNativeConstraintOp>(decl, loc,
+                                                               inputs);
+  constrainPairwise(decl->getResults(), results);
+  return results;
+}
+
+/// Generate a call to the user rewrite `decl` with the given input values.
+/// Unlike genConstraintCall, no argument or result constraints are applied.
+SmallVector<Value> CodeGen::genRewriteCall(const ast::UserRewriteDecl *decl,
+ Location loc, ValueRange inputs) {
+ return genConstraintOrRewriteCall<pdl::ApplyNativeRewriteOp>(decl, loc,
+ inputs);
+}
+
+/// Lower a call to a user constraint or rewrite `decl` with the given inputs.
+///
+/// A native decl (one without a body) becomes a single `PDLOpT` operation
+/// whose result types are derived from the decl's result type. A PDLL decl has
+/// its body inlined at the call site with the decl's arguments bound to
+/// `inputs`. Its results are the values of the trailing `return` statement
+/// (none if the body doesn't end in one).
+template <typename PDLOpT, typename T>
+SmallVector<Value> CodeGen::genConstraintOrRewriteCall(const T *decl,
+                                                       Location loc,
+                                                       ValueRange inputs) {
+  const ast::CompoundStmt *cstBody = decl->getBody();
+
+  // If the decl doesn't have a statement body, it is a native decl.
+  if (!cstBody) {
+    // A tuple result type yields one result per element.
+    ast::Type declResultType = decl->getResultType();
+    SmallVector<Type> resultTypes;
+    if (ast::TupleType tupleType = declResultType.dyn_cast<ast::TupleType>()) {
+      for (ast::Type type : tupleType.getElementTypes())
+        resultTypes.push_back(genType(type));
+    } else {
+      resultTypes.push_back(genType(declResultType));
+    }
+
+    // FIXME: We currently do not have a modeling for the "constant params"
+    // support PDL provides. We should either figure out a modeling for this, or
+    // refactor the support within PDL to be something a bit more reasonable for
+    // what we need as a frontend.
+    Operation *pdlOp = builder.create<PDLOpT>(loc, resultTypes,
+                                              decl->getName().getName(), inputs,
+                                              /*params=*/ArrayAttr());
+    return pdlOp->getResults();
+  }
+
+  // Otherwise, this is a PDLL decl.
+  VariableMapTy::ScopeTy varScope(variables);
+
+  // Map the inputs of the call to the decl arguments.
+  // Note: This is only valid because we do not support recursion, meaning
+  // we don't need to worry about conflicting mappings here.
+  for (auto it : llvm::zip(inputs, decl->getInputs()))
+    variables.insert(std::get<1>(it), {std::get<0>(it)});
+
+  // Visit the body statements directly within `varScope`. Going through
+  // gen(cstBody) would open, and then pop, a nested scope for the compound
+  // statement, dropping the body's local variables before the return
+  // expression below is generated. A returned local would then be generated a
+  // second time, duplicating its IR and constraints, or, for a variable without
+  // an initializer, yielding a fresh value unrelated to the one the body
+  // constrained.
+  for (const ast::Stmt *bodyStmt : cstBody->getChildren())
+    gen(bodyStmt);
+
+  // If the decl has no results, there is nothing to do.
+  if (cstBody->getChildren().empty())
+    return SmallVector<Value>();
+  auto *returnStmt = dyn_cast<ast::ReturnStmt>(cstBody->getChildren().back());
+  if (!returnStmt)
+    return SmallVector<Value>();
+
+  // Otherwise, grab the results from the return statement. Body locals are
+  // still in `varScope` here, so references to them reuse existing values.
+  return genExpr(returnStmt->getResultExpr());
+}
+
+//===----------------------------------------------------------------------===//
+// MLIRGen
+//===----------------------------------------------------------------------===//
+
+OwningOpRef<ModuleOp> mlir::pdll::codegenPDLLToMLIR(
+    MLIRContext *mlirContext, const ast::Context &context,
+    const llvm::SourceMgr &sourceMgr, const ast::Module &module) {
+  CodeGen generator(mlirContext, context, sourceMgr);
+  OwningOpRef<ModuleOp> result = generator.generate(module);
+  // Only hand back IR that verifies; otherwise return a null module.
+  if (succeeded(verify(*result)))
+    return result;
+  return nullptr;
+}
diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/decl.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/decl.pdll
new file mode 100644
index 0000000000000..14098f55dfb9e
--- /dev/null
+++ b/mlir/test/mlir-pdll/CodeGen/MLIR/decl.pdll
@@ -0,0 +1,97 @@
+// RUN: mlir-pdll %s -I %S -split-input-file -x mlir | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// PatternDecl
+//===----------------------------------------------------------------------===//
+
+// An unnamed pattern is emitted without a symbol name.
+// CHECK: pdl.pattern : benefit(0) {
+Pattern => erase _: Op;
+
+// -----
+
+// CHECK: pdl.pattern @NamedPattern : benefit(0) {
+Pattern NamedPattern => erase _: Op;
+
+// -----
+
+// The `recursion` flag is not yet modeled by the PDL generator, so only the
+// benefit is verified here.
+// CHECK: pdl.pattern @NamedPattern : benefit(10) {
+Pattern NamedPattern with benefit(10), recursion => erase _: Op;
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// VariableDecl
+//===----------------------------------------------------------------------===//
+
+// Test the case of a variable with an initializer.
+// The initializer is generated once; both the rewrite root and the erase refer
+// to that same value.
+
+// CHECK: pdl.pattern @VarWithInit
+// CHECK: %[[INIT:.*]] = operation "test.op"
+// CHECK: rewrite %[[INIT]] {
+// CHECK: erase %[[INIT]]
+Pattern VarWithInit {
+ let var = op<test.op>;
+ erase var;
+}
+
+// -----
+
+// Test range based constraints.
+// The TypeRange constraint on the operand range becomes the type operand of
+// the generated `operands` op.
+
+// CHECK: pdl.pattern @VarWithRangeConstraints
+// CHECK: %[[OPERAND_TYPES:.*]] = types
+// CHECK: %[[OPERANDS:.*]] = operands : %[[OPERAND_TYPES]]
+// CHECK: %[[RESULT_TYPES:.*]] = types
+// CHECK: operation(%[[OPERANDS]] : !pdl.range<value>) -> (%[[RESULT_TYPES]] : !pdl.range<type>)
+Pattern VarWithRangeConstraints {
+ erase op<>(operands: ValueRange<operandTypes: TypeRange>) -> (results: TypeRange);
+}
+
+// -----
+
+// Test single entity constraints.
+// Each Value/Attr type constraint is generated as its own `type` value and is
+// used as the type operand of the corresponding `operand`/`attribute` op.
+
+// CHECK: pdl.pattern @VarWithConstraints
+// CHECK: %[[OPERAND_TYPE:.*]] = type
+// CHECK: %[[OPERAND:.*]] = operand : %[[OPERAND_TYPE]]
+// CHECK: %[[ATTR_TYPE:.*]] = type
+// CHECK: %[[ATTR:.*]] = attribute : %[[ATTR_TYPE]]
+// CHECK: %[[RESULT_TYPE:.*]] = type
+// CHECK: operation(%[[OPERAND]] : !pdl.value) {"attr" = %[[ATTR]]} -> (%[[RESULT_TYPE]] : !pdl.type)
+Pattern VarWithConstraints {
+ erase op<>(operand: Value<operandType: Type>) { attr = _: Attr<attrType: Type>} -> (result: Type);
+}
+
+// -----
+
+// Test op constraint.
+// A bare `Op` variable matches any operands and any result types; an
+// `Op<name>` constraint additionally fixes the operation name.
+
+// CHECK: pdl.pattern @VarWithNoNameOpConstraint
+// CHECK: %[[OPERANDS:.*]] = operands
+// CHECK: %[[RESULT_TYPES:.*]] = types
+// CHECK: operation(%[[OPERANDS]] : !pdl.range<value>) -> (%[[RESULT_TYPES]] : !pdl.range<type>)
+Pattern VarWithNoNameOpConstraint => erase _: Op;
+
+// CHECK: pdl.pattern @VarWithNamedOpConstraint
+// CHECK: %[[OPERANDS:.*]] = operands
+// CHECK: %[[RESULT_TYPES:.*]] = types
+// CHECK: operation "test.op"(%[[OPERANDS]] : !pdl.range<value>) -> (%[[RESULT_TYPES]] : !pdl.range<type>)
+Pattern VarWithNamedOpConstraint => erase _: Op<test.op>;
+
+// -----
+
+// Test user defined constraints.
+// TestArgResCsts is a PDLL constraint, so it is inlined: its argument
+// constraint is applied before its result constraint, followed by the
+// directly attached native OpCst.
+
+// CHECK: pdl.pattern @VarWithUserConstraint
+// CHECK: %[[OPERANDS:.*]] = operands
+// CHECK: %[[RESULT_TYPES:.*]] = types
+// CHECK: %[[OP:.*]] = operation(%[[OPERANDS]] : !pdl.range<value>) -> (%[[RESULT_TYPES]] : !pdl.range<type>)
+// CHECK: apply_native_constraint "NestedArgCst"(%[[OP]] : !pdl.operation)
+// CHECK: apply_native_constraint "NestedResCst"(%[[OP]] : !pdl.operation)
+// CHECK: apply_native_constraint "OpCst"(%[[OP]] : !pdl.operation)
+// CHECK: rewrite %[[OP]]
+Constraint NestedArgCst(op: Op);
+Constraint NestedResCst(op: Op);
+Constraint TestArgResCsts(op: NestedArgCst) -> NestedResCst => op;
+Constraint OpCst(op: Op);
+Pattern VarWithUserConstraint => erase _: [TestArgResCsts, OpCst];
diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
new file mode 100644
index 0000000000000..4205e56fad54b
--- /dev/null
+++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
@@ -0,0 +1,93 @@
+// RUN: mlir-pdll %s -I %S -split-input-file -x mlir | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// AttributeExpr
+//===----------------------------------------------------------------------===//
+
+// CHECK: pdl.pattern @AttrExpr
+// CHECK: %[[ATTR:.*]] = attribute 10
+// CHECK: operation {"attr" = %[[ATTR]]}
+Pattern AttrExpr => erase op<> { attr = attr<"10"> };
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// CallExpr
+//===----------------------------------------------------------------------===//
+
+// CHECK: pdl.pattern @TestCallWithArgsAndReturn
+// CHECK: %[[ROOT:.*]] = operation
+// CHECK: rewrite %[[ROOT]]
+// CHECK: %[[REPL_OP:.*]] = operation "test.op"
+// CHECK: %[[RESULTS:.*]] = results of %[[REPL_OP]]
+// CHECK: replace %[[ROOT]] with(%[[RESULTS]] : !pdl.range<value>)
+Rewrite TestRewrite(root: Op) -> ValueRange => root;
+Pattern TestCallWithArgsAndReturn => replace root: Op with TestRewrite(op<test.op>);
+
+// -----
+
+// CHECK: pdl.pattern @TestExternalCall
+// CHECK: %[[ROOT:.*]] = operation
+// CHECK: rewrite %[[ROOT]]
+// CHECK: %[[RESULTS:.*]] = apply_native_rewrite "TestRewrite"(%[[ROOT]] : !pdl.operation) : !pdl.range<value>
+// CHECK: replace %[[ROOT]] with(%[[RESULTS]] : !pdl.range<value>)
+Rewrite TestRewrite(op: Op) -> ValueRange;
+Pattern TestExternalCall => replace root: Op with TestRewrite(root);
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// MemberAccessExpr
+//===----------------------------------------------------------------------===//
+
+// Handle implicit "all" operation results access.
+// CHECK: pdl.pattern @OpAllResultMemberAccess
+// CHECK: %[[OP0:.*]] = operation
+// CHECK: %[[OP0_RES:.*]] = result 0 of %[[OP0]]
+// CHECK: %[[OP1:.*]] = operation
+// CHECK: %[[OP1_RES:.*]] = results of %[[OP1]]
+// CHECK: operation(%[[OP0_RES]], %[[OP1_RES]] : !pdl.value, !pdl.range<value>)
+Pattern OpAllResultMemberAccess {
+ let singleVar: Value = op<>;
+ let rangeVar: ValueRange = op<>;
+ erase op<>(singleVar, rangeVar);
+}
+
+// -----
+
+// CHECK: pdl.pattern @TupleMemberAccessNumber
+// CHECK: %[[FIRST:.*]] = operation "test.first"
+// CHECK: %[[SECOND:.*]] = operation "test.second"
+// CHECK: rewrite %[[FIRST]] {
+// CHECK: replace %[[FIRST]] with %[[SECOND]]
+Pattern TupleMemberAccessNumber {
+ let firstOp = op<test.first>;
+ let secondOp = op<test.second>(firstOp);
+ let tuple = (firstOp, secondOp);
+ replace tuple.0 with tuple.1;
+}
+
+// -----
+
+// CHECK: pdl.pattern @TupleMemberAccessName
+// CHECK: %[[FIRST:.*]] = operation "test.first"
+// CHECK: %[[SECOND:.*]] = operation "test.second"
+// CHECK: rewrite %[[FIRST]] {
+// CHECK: replace %[[FIRST]] with %[[SECOND]]
+Pattern TupleMemberAccessName {
+ let firstOp = op<test.first>;
+ let secondOp = op<test.second>(firstOp);
+ let tuple = (first = firstOp, second = secondOp);
+ replace tuple.first with tuple.second;
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// TypeExpr
+//===----------------------------------------------------------------------===//
+
+// CHECK: pdl.pattern @TypeExpr
+// CHECK: %[[TYPE:.*]] = type : i32
+// CHECK: operation -> (%[[TYPE]] : !pdl.type)
+Pattern TypeExpr => erase op<> -> (type<"i32">);
diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/stmt.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/stmt.pdll
new file mode 100644
index 0000000000000..0fc13caca83a2
--- /dev/null
+++ b/mlir/test/mlir-pdll/CodeGen/MLIR/stmt.pdll
@@ -0,0 +1,61 @@
+// RUN: mlir-pdll %s -I %S -split-input-file -x mlir | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// EraseStmt
+//===----------------------------------------------------------------------===//
+
+// CHECK: pdl.pattern @EraseStmt
+// CHECK: %[[OP:.*]] = operation
+// CHECK: rewrite %[[OP]]
+// CHECK: erase %[[OP]]
+Pattern EraseStmt => erase op<>;
+
+// -----
+
+// CHECK: pdl.pattern @EraseStmtNested
+// CHECK: %[[OP:.*]] = operation
+// CHECK: rewrite %[[OP]]
+// CHECK: erase %[[OP]]
+Pattern EraseStmtNested => rewrite root: Op with { erase root; };
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// ReplaceStmt
+//===----------------------------------------------------------------------===//
+
+// CHECK: pdl.pattern @ReplaceStmt
+// CHECK: %[[OPERANDS:.*]] = operands
+// CHECK: %[[OP:.*]] = operation(%[[OPERANDS]]
+// CHECK: rewrite %[[OP]]
+// CHECK: replace %[[OP]] with(%[[OPERANDS]] : !pdl.range<value>)
+Pattern ReplaceStmt => replace op<>(operands: ValueRange) with operands;
+
+// -----
+
+// CHECK: pdl.pattern @ReplaceStmtNested
+// CHECK: %[[OPERANDS:.*]] = operands
+// CHECK: %[[OP:.*]] = operation(%[[OPERANDS]]
+// CHECK: rewrite %[[OP]]
+// CHECK: replace %[[OP]] with(%[[OPERANDS]] : !pdl.range<value>)
+Pattern ReplaceStmtNested {
+ let root = op<>(operands: ValueRange);
+ rewrite root with { replace root with operands; };
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// RewriteStmt
+//===----------------------------------------------------------------------===//
+
+// CHECK: pdl.pattern @RewriteStmtNested
+// CHECK: %[[OP:.*]] = operation
+// CHECK: rewrite %[[OP]]
+// CHECK: erase %[[OP]]
+Pattern RewriteStmtNested {
+ rewrite root: Op with {
+ rewrite root with { erase root; };
+ };
+}
+
diff --git a/mlir/test/mlir-pdll/Parser/lit.local.cfg b/mlir/test/mlir-pdll/lit.local.cfg
similarity index 100%
rename from mlir/test/mlir-pdll/Parser/lit.local.cfg
rename to mlir/test/mlir-pdll/lit.local.cfg
diff --git a/mlir/tools/mlir-pdll/CMakeLists.txt b/mlir/tools/mlir-pdll/CMakeLists.txt
index d8ec702180978..4573e39928b74 100644
--- a/mlir/tools/mlir-pdll/CMakeLists.txt
+++ b/mlir/tools/mlir-pdll/CMakeLists.txt
@@ -1,5 +1,6 @@
set(LIBS
MLIRPDLLAST
+ MLIRPDLLCodeGen
MLIRPDLLParser
)
diff --git a/mlir/tools/mlir-pdll/mlir-pdll.cpp b/mlir/tools/mlir-pdll/mlir-pdll.cpp
index 5600e66eafdb3..6ee2b8bf11a55 100644
--- a/mlir/tools/mlir-pdll/mlir-pdll.cpp
+++ b/mlir/tools/mlir-pdll/mlir-pdll.cpp
@@ -6,10 +6,12 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/ToolUtilities.h"
#include "mlir/Tools/PDLL/AST/Context.h"
#include "mlir/Tools/PDLL/AST/Nodes.h"
+#include "mlir/Tools/PDLL/CodeGen/MLIRGen.h"
#include "mlir/Tools/PDLL/Parser/Parser.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
@@ -26,6 +28,7 @@ using namespace mlir::pdll;
/// The desired output type.
enum class OutputType {
AST,
+ MLIR,
};
static LogicalResult
@@ -40,12 +43,18 @@ processBuffer(raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
if (failed(module))
return failure();
- switch (outputType) {
- case OutputType::AST:
+ if (outputType == OutputType::AST) {
(*module)->print(os);
- break;
+ return success();
}
+ MLIRContext mlirContext;
+ OwningOpRef<ModuleOp> pdlModule =
+ codegenPDLLToMLIR(&mlirContext, astContext, sourceMgr, **module);
+ if (!pdlModule)
+ return failure();
+
+ pdlModule->print(os, OpPrintingFlags().enableDebugInfo());
return success();
}
@@ -71,7 +80,9 @@ int main(int argc, char **argv) {
"x", llvm::cl::init(OutputType::AST),
llvm::cl::desc("The type of output desired"),
llvm::cl::values(clEnumValN(OutputType::AST, "ast",
- "generate the AST for the input file")));
+ "generate the AST for the input file"),
+ clEnumValN(OutputType::MLIR, "mlir",
+ "generate the PDL MLIR for the input file")));
llvm::InitLLVM y(argc, argv);
llvm::cl::ParseCommandLineOptions(argc, argv, "PDLL Frontend");
More information about the Mlir-commits mailing list