[Mlir-commits] [mlir] df969f6 - [mlir] fixes to transform::SequenceOp

Alex Zinenko llvmlistbot at llvm.org
Mon Dec 12 01:20:46 PST 2022


Author: Alex Zinenko
Date: 2022-12-12T10:20:38+01:00
New Revision: df969f66ef43fea537503369651a12429ba1cde1

URL: https://github.com/llvm/llvm-project/commit/df969f66ef43fea537503369651a12429ba1cde1
DIFF: https://github.com/llvm/llvm-project/commit/df969f66ef43fea537503369651a12429ba1cde1.diff

LOG: [mlir] fixes to transform::SequenceOp

Harden the verifier to check that the block argument type matches the
operand type when the operand is present. This check was overlooked when
transform dialect types were introduced.

Fix the builders to save the insertion point before creating the block.
Otherwise, `createBlock` moves the insertion point into the new block and
it is never reset to just after the sequence op itself. As a result, all
subsequently created operations end up in a place the caller does not
expect.

Reviewed By: chelini, springerm

Differential Revision: https://reviews.llvm.org/D139427

Added: 
    

Modified: 
    mlir/lib/Dialect/Transform/IR/TransformOps.cpp
    mlir/test/Dialect/Transform/ops-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 76f6485c9ff9b..e3607b2f0b96c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -645,13 +645,22 @@ checkDoubleConsume(Value value,
 }
 
 LogicalResult transform::SequenceOp::verify() {
+  assert(getBodyBlock()->getNumArguments() == 1 &&
+         "the number of arguments must have been verified to be 1 by "
+         "PossibleTopLevelTransformOpTrait");
+
+  BlockArgument arg = getBodyBlock()->getArgument(0);
+  if (getRoot()) {
+    if (arg.getType() != getRoot().getType()) {
+      return emitOpError() << "expects the type of the block argument to match "
+                              "the type of the operand";
+    }
+  }
+
   // Check if the block argument has more than one consuming use.
-  for (BlockArgument argument : getBodyBlock()->getArguments()) {
-    auto report = [&]() {
-      return (emitOpError() << "block argument #" << argument.getArgNumber());
-    };
-    if (failed(checkDoubleConsume(argument, report)))
-      return failure();
+  if (failed(checkDoubleConsume(
+          arg, [this]() { return (emitOpError() << "block argument #0"); }))) {
+    return failure();
   }
 
   // Check properties of the nested operations they cannot check themselves.
@@ -765,12 +774,12 @@ void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
                                   SequenceBodyBuilderFn bodyBuilder) {
   build(builder, state, resultTypes, failurePropagationMode, root);
   Region *region = state.regions.back().get();
-  auto bbArgType = root.getType();
+  Type bbArgType = root.getType();
+  OpBuilder::InsertionGuard guard(builder);
   Block *bodyBlock = builder.createBlock(
       region, region->begin(), TypeRange{bbArgType}, {state.location});
 
   // Populate body.
-  OpBuilder::InsertionGuard guard(builder);
   builder.setInsertionPointToStart(bodyBlock);
   bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
 }
@@ -782,11 +791,11 @@ void transform::SequenceOp::build(OpBuilder &builder, OperationState &state,
                                   SequenceBodyBuilderFn bodyBuilder) {
   build(builder, state, resultTypes, failurePropagationMode, /*root=*/Value());
   Region *region = state.regions.back().get();
+  OpBuilder::InsertionGuard guard(builder);
   Block *bodyBlock = builder.createBlock(
       region, region->begin(), TypeRange{bbArgType}, {state.location});
 
   // Populate body.
-  OpBuilder::InsertionGuard guard(builder);
   builder.setInsertionPointToStart(bodyBlock);
   bodyBuilder(builder, state.location, bodyBlock->getArgument(0));
 }

diff  --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 98a197970d4bf..369be213d3ac4 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -35,6 +35,17 @@ transform.sequence failures(propagate) {
 
 // -----
 
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{expects the type of the block argument to match the type of the operand}}
+  transform.sequence %arg0: !transform.any_op failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    transform.yield
+  }
+}
+
+// -----
+
 // expected-note @below {{nested in another possible top-level op}}
 transform.with_pdl_patterns {
 ^bb0(%arg0: !pdl.operation):


        


More information about the Mlir-commits mailing list