[Mlir-commits] [mlir] 8ddd98f - [mlir][linalg] Return newly created ops from bufferize_to_allocation

Matthias Springer llvmlistbot at llvm.org
Tue Jul 11 07:34:14 PDT 2023


Author: Matthias Springer
Date: 2023-07-11T16:34:02+02:00
New Revision: 8ddd98f83136646bd4f88fe35919b0af03337b75

URL: https://github.com/llvm/llvm-project/commit/8ddd98f83136646bd4f88fe35919b0af03337b75
DIFF: https://github.com/llvm/llvm-project/commit/8ddd98f83136646bd4f88fe35919b0af03337b75.diff

LOG: [mlir][linalg] Return newly created ops from bufferize_to_allocation

Return all ops that were generated as part of the bufferization, so that users do not have to match them in the enclosing op.

Differential Revision: https://reviews.llvm.org/D154966

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/IR/PatternMatch.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir
    mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 2979a8018cdf3a..4a143a158867cb 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -87,8 +87,8 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
     result in a new allocation. It replaces all original uses of the target
     result with the newly allocated buffer, wrapped in a
     `bufferization.to_tensor` op. It returns a handle to the newly allocated
-    buffer. Furthermore, it returns a handle to the result of the `to_tensor`
-    op.
+    buffer. Furthermore, it returns a handle that is mapped to all newly created
+    ops.
 
     Only bufferizable ops that bufferize to a memory write or have an
     aliasing OpOperand (and do not themselves bufferize to an allocation) are
@@ -121,12 +121,13 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
     #### Return modes
 
     This operation consumes the `target` handle and produces the
-    `allocated_buffer` handle. It always succeeds.
+    `allocated_buffer` and `new_ops` handles. It always succeeds.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
                        OptionalAttr<AnyAttr>:$memory_space);
-  let results = (outs Transform_AnyValue:$allocated_buffer);
+  let results = (outs Transform_AnyValue:$allocated_buffer,
+                      Transform_AnyOpType:$new_ops);
   let assemblyFormat = "$target attr-dict `:` type($target)";
   
   let builders = [

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index fdabbcc05181dd..6625ef553eba21 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -444,7 +444,7 @@ class RewriterBase : public OpBuilder {
   /// struct can be used as a base to create listener chains, so that multiple
   /// listeners can be notified of IR changes.
   struct ForwardingListener : public RewriterBase::Listener {
-    ForwardingListener(Listener *listener) : listener(listener) {}
+    ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
 
     void notifyOperationInserted(Operation *op) override {
       listener->notifyOperationInserted(op);
@@ -453,26 +453,32 @@ class RewriterBase : public OpBuilder {
       listener->notifyBlockCreated(block);
     }
     void notifyOperationModified(Operation *op) override {
-      listener->notifyOperationModified(op);
+      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+        rewriteListener->notifyOperationModified(op);
     }
     void notifyOperationReplaced(Operation *op, Operation *newOp) override {
-      listener->notifyOperationReplaced(op, newOp);
+      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+        rewriteListener->notifyOperationReplaced(op, newOp);
     }
     void notifyOperationReplaced(Operation *op,
                                  ValueRange replacement) override {
-      listener->notifyOperationReplaced(op, replacement);
+      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+        rewriteListener->notifyOperationReplaced(op, replacement);
     }
     void notifyOperationRemoved(Operation *op) override {
-      listener->notifyOperationRemoved(op);
+      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+        rewriteListener->notifyOperationRemoved(op);
     }
     LogicalResult notifyMatchFailure(
         Location loc,
         function_ref<void(Diagnostic &)> reasonCallback) override {
-      return listener->notifyMatchFailure(loc, reasonCallback);
+      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+        return rewriteListener->notifyMatchFailure(loc, reasonCallback);
+      return failure();
     }
 
   private:
-    Listener *listener;
+    OpBuilder::Listener *listener;
   };
 
   /// Move the blocks that belong to "region" before the given position in

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7ca6b272103b7a..5474377ee364b0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -177,8 +177,11 @@ void transform::BufferizeToAllocationOp::build(OpBuilder &b,
                                                OperationState &result,
                                                Value target,
                                                Attribute memorySpace) {
+  SmallVector<Type> resultTypes;
+  resultTypes.push_back(b.getType<transform::AnyValueType>());
+  resultTypes.push_back(b.getType<transform::AnyOpType>());
   return build(b, result,
-               /*resultTypes=*/b.getType<transform::AnyValueType>(),
+               /*resultTypes=*/resultTypes,
                /*target=*/target,
                /*memorySpace=*/memorySpace);
 }
@@ -187,15 +190,52 @@ void transform::BufferizeToAllocationOp::build(OpBuilder &b,
                                                OperationState &result,
                                                Value target,
                                                int64_t memorySpace) {
+  SmallVector<Type> resultTypes;
+  resultTypes.push_back(b.getType<transform::AnyValueType>());
+  resultTypes.push_back(b.getType<transform::AnyOpType>());
   return build(b, result,
-               /*resultTypes=*/b.getType<transform::AnyValueType>(),
+               /*resultTypes=*/resultTypes,
                /*target=*/target,
                /*memorySpace=*/b.getI64IntegerAttr(memorySpace));
 }
 
+namespace {
+class NewOpsListener : public RewriterBase::ForwardingListener {
+public:
+  using RewriterBase::ForwardingListener::ForwardingListener;
+
+  SmallVector<Operation *> getNewOps() const {
+    return SmallVector<Operation *>(newOps.begin(), newOps.end());
+  }
+
+private:
+  void notifyOperationInserted(Operation *op) override {
+    ForwardingListener::notifyOperationInserted(op);
+    auto inserted = newOps.insert(op);
+    (void)inserted;
+    assert(inserted.second && "expected newly created op");
+  }
+
+  void notifyOperationRemoved(Operation *op) override {
+    ForwardingListener::notifyOperationRemoved(op);
+    op->walk([&](Operation *op) { newOps.erase(op); });
+  }
+
+  DenseSet<Operation *> newOps;
+};
+} // namespace
+
 DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
     transform::TransformRewriter &rewriter,
     transform::TransformResults &results, transform::TransformState &state) {
+  // Attach listener to keep track of newly created ops.
+  OpBuilder::Listener *previousListener = rewriter.getListener();
+  auto resetListener =
+      llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
+  NewOpsListener newOpsListener(previousListener);
+  rewriter.setListener(&newOpsListener);
+
+  // Bufferize ops.
   Attribute memorySpace =
       getMemorySpace().has_value() ? getMemorySpace().value() : Attribute();
   SmallVector<Value> allocatedBuffers;
@@ -209,7 +249,10 @@ DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(
     }
     allocatedBuffers.push_back(buffer);
   }
+
+  // Set results.
   results.setValues(cast<OpResult>(getAllocatedBuffer()), allocatedBuffers);
+  results.set(cast<OpResult>(getNewOps()), newOpsListener.getNewOps());
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -217,6 +260,7 @@ void transform::BufferizeToAllocationOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   consumesHandle(getTarget(), effects);
   producesHandle(getAllocatedBuffer(), effects);
+  producesHandle(getNewOps(), effects);
   modifiesPayload(effects);
 }
 

diff --git a/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir b/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir
index 55b0ff5016d177..2643935049c9f3 100644
--- a/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir
+++ b/mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir
@@ -54,7 +54,7 @@ transform.sequence failures(propagate) {
     padding_dimensions=[0, 1, 2],
     pack_paddings=[1, 1, 1]
   } : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
-  %buffer = transform.structured.bufferize_to_allocation %pad {memory_space = 3} : !transform.any_op
+  %buffer, %new_ops = transform.structured.bufferize_to_allocation %pad {memory_space = 3} : !transform.any_op
   %2 = transform.bufferization.one_shot_bufferize %arg1 {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op
 
 }
@@ -114,6 +114,6 @@ transform.sequence failures(propagate) {
   transform.structured.masked_vectorize %pad vector_sizes [10, 12] : !transform.any_op
   %vector_write = transform.structured.match ops{["vector.transfer_write"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %mask_op = transform.get_parent_op %vector_write {op_name = "vector.mask"} : (!transform.any_op) -> !transform.any_op
-  %buffer = transform.structured.bufferize_to_allocation %mask_op {memory_space = 3} : !transform.any_op
+  %buffer, %new_ops = transform.structured.bufferize_to_allocation %mask_op {memory_space = 3} : !transform.any_op
   %2 = transform.bufferization.one_shot_bufferize %arg1 {bufferize_function_boundaries=true} : (!transform.any_op) -> !transform.any_op
 }

diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
index e51c20334e9554..45efde3b077a44 100644
--- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -32,7 +32,17 @@ func.func @tensor_pad_constant(%t: tensor<?x10xindex>, %l2: index, %h1: index,
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  %2 = transform.structured.bufferize_to_allocation %0 : !transform.any_op
+  %2, %new = transform.structured.bufferize_to_allocation %0 : !transform.any_op
+
+  // Ensure that one linalg.fill was generated.
+  %fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
+  // expected-remark @below{{1}}
+  test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
+
+  // Ensure that one memref.tensor_store was generated.
+  %tensor_store = transform.select "memref.tensor_store" in %new : (!transform.any_op) -> !transform.any_op
+  // expected-remark @below{{1}}
+  test_print_number_of_associated_payload_ir_ops %tensor_store : !transform.any_op
 }
 
 // -----
@@ -57,7 +67,7 @@ func.func @tensor_pad_constant(%t: tensor<?x10xindex>, %l2: index, %h1: index,
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  %2 = transform.structured.bufferize_to_allocation %0 : !transform.any_op
+  %2, %new = transform.structured.bufferize_to_allocation %0 : !transform.any_op
   // Make sure that One-Shot Bufferize can bufferize the rest.
   %4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
 }
@@ -81,7 +91,7 @@ func.func @tensor_insert(%t: tensor<?x10xindex>, %idx: index, %v: index) -> tens
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["tensor.insert"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  %2 = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
+  %2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
   // Make sure that One-Shot Bufferize can bufferize the rest.
   %4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
 }
@@ -104,7 +114,7 @@ func.func @tensor_insert_into_empty(%idx: index, %v: index) -> tensor<10xindex>
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["tensor.insert"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  %2 = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
+  %2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
   // Make sure that One-Shot Bufferize can bufferize the rest.
   %4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
 }
@@ -121,7 +131,7 @@ transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["tensor.extract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   // expected-error @below{{failed to bufferize operation}}
-  %2 = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
+  %2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
 }
 
 // -----
@@ -142,5 +152,5 @@ func.func @vector_mask(%t: tensor<?xf32>, %val: vector<16xf32>, %idx: index, %m0
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["vector.mask"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-  %2 = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
+  %2, %new = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
 }


        


More information about the Mlir-commits mailing list