[Mlir-commits] [mlir] 00c95b1 - [mlir][transform] Add C++ builder to SequenceOp

Matthias Springer llvmlistbot at llvm.org
Thu Nov 17 07:01:57 PST 2022


Author: Matthias Springer
Date: 2022-11-17T15:58:13+01:00
New Revision: 00c95b19d7963dae4e8bdee66a9880d44761cffe

URL: https://github.com/llvm/llvm-project/commit/00c95b19d7963dae4e8bdee66a9880d44761cffe
DIFF: https://github.com/llvm/llvm-project/commit/00c95b19d7963dae4e8bdee66a9880d44761cffe.diff

LOG: [mlir][transform] Add C++ builder to SequenceOp

This change adds a builder that populates the body of a SequenceOp. This is useful for constructing SequenceOps from C++.

Differential Revision: https://reviews.llvm.org/D137710

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
    mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
index b9f4eed2b200d..822c24d28bc57 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -21,6 +21,10 @@ namespace mlir {
 namespace transform {
 enum class FailurePropagationMode : uint32_t;
 class FailurePropagationModeAttr;
+
+/// Callback that populates a SequenceOp body, given the body's block argument.
+using SequenceBodyBuilderFn = ::llvm::function_ref<void(
+    ::mlir::OpBuilder &, ::mlir::Location, ::mlir::BlockArgument)>;
 } // namespace transform
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 42f8d5cb27698..d81bea63821a6 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -389,6 +389,10 @@ def SequenceOp : TransformDialectOp<"sequence",
     IR, typically the root operation of the pass interpreting the transform
     dialect. Operand omission is only allowed for sequences not contained in
     another sequence.
+
+    The body of the sequence terminates with an implicit or explicit
+    `transform.yield` op. The operands of the terminator are returned as the
+    results of the sequence op.
   }];
 
   let arguments = (ins FailurePropagationMode:$failure_propagation_mode,
@@ -400,6 +404,20 @@ def SequenceOp : TransformDialectOp<"sequence",
     "($root^ `:` type($root))? (`->` type($results)^)? `failures` `(` "
     "$failure_propagation_mode `)` attr-dict-with-keyword regions";
 
+  let builders = [
+    // Build a sequence with a root. The bbArg type is taken from the root.
+    OpBuilder<(ins
+        "::mlir::TypeRange":$resultTypes,
+        "::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
+        "::mlir::Value":$root, "SequenceBodyBuilderFn":$bodyBuilder)>,
+
+    // Build a sequence without a root but with a certain bbArg type.
+    OpBuilder<(ins
+        "::mlir::TypeRange":$resultTypes,
+        "::mlir::transform::FailurePropagationMode":$failure_propagation_mode,
+        "::mlir::Type":$bbArgType, "SequenceBodyBuilderFn":$bodyBuilder)>
+  ];
+
   let extraClassDeclaration = [{
     /// Allow the dialect prefix to be omitted.
     static StringRef getDefaultDialect() { return "transform"; }

diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index af759d96c774a..76e0c89adb7d4 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -765,6 +765,39 @@ void transform::SequenceOp::getRegionInvocationBounds(
   bounds.emplace_back(1, 1);
 }
 
+/// Builds a SequenceOp on `root` and lets `bodyBuilder` populate its body.
+void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
+                                  TypeRange resultTypes,
+                                  FailurePropagationMode failurePropagationMode,
+                                  Value root,
+                                  SequenceBodyBuilderFn bodyBuilder) {
+  build(builder, state, resultTypes, failurePropagationMode, root);
+  // Guard first: createBlock below moves the builder's insertion point.
+  OpBuilder::InsertionGuard guard(builder);
+  Region *region = state.regions.back().get();
+  Block *bodyBlock = builder.createBlock(
+      region, region->begin(), TypeRange{root.getType()}, {state.location});
+  // Populate body; its single block argument has the type of the root.
+  builder.setInsertionPointToStart(bodyBlock);
+  bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
+}
+
+/// Builds a root-less SequenceOp whose body block argument has `bbArgType`.
+void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
+                                  TypeRange resultTypes,
+                                  FailurePropagationMode failurePropagationMode,
+                                  Type bbArgType,
+                                  SequenceBodyBuilderFn bodyBuilder) {
+  build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value());
+  // Guard first: createBlock below moves the builder's insertion point.
+  OpBuilder::InsertionGuard guard(builder);
+  Region *region = state.regions.back().get();
+  Block *bodyBlock = builder.createBlock(
+      region, region->begin(), TypeRange{bbArgType}, {state.location});
+  builder.setInsertionPointToStart(bodyBlock);
+  bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
+}
+
 //===----------------------------------------------------------------------===//
 // WithPDLPatternsOp
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list