[Mlir-commits] [mlir] eb74eff - [mlir][linalg] BufferizeToAllocationOp: Support vector.mask

Matthias Springer llvmlistbot at llvm.org
Tue Jul 4 05:54:09 PDT 2023


Author: Matthias Springer
Date: 2023-07-04T14:53:43+02:00
New Revision: eb74eff9d2271a820beebc2814e63e0e1016a336

URL: https://github.com/llvm/llvm-project/commit/eb74eff9d2271a820beebc2814e63e0e1016a336
DIFF: https://github.com/llvm/llvm-project/commit/eb74eff9d2271a820beebc2814e63e0e1016a336.diff

LOG: [mlir][linalg] BufferizeToAllocationOp: Support vector.mask

This op needs special handling because the allocation for the masked op must be placed outside of the vector.mask op, directly in front of it, rather than inside its region.
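
For illustration, here is a before/after sketch assembled from the doc
comment and test case in this patch (%sz is a stand-in for the dynamically
computed allocation size, e.g. a tensor.dim of %t):

  // Before: masked transfer_write on a tensor.
  %r = vector.mask %m0 {
    vector.transfer_write %val, %t[%idx] : vector<16xf32>, tensor<?xf32>
  } : vector<16xi1> -> tensor<?xf32>

  // After: the allocation is created in front of the vector.mask op and
  // the masked op now writes into the new buffer.
  %alloc = memref.alloc(%sz) : memref<?xf32>
  memref.tensor_store %t, %alloc
  vector.mask %m0 {
    vector.transfer_write %val, %alloc[%idx] : vector<16xf32>, memref<?xf32>
  } : vector<16xi1>
  %r = bufferization.to_tensor %alloc restrict writable

The resulting to_tensor op is marked "restrict writable" because the buffer
was just allocated, so no other tensor can alias it.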

Differential Revision: https://reviews.llvm.org/D154058

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
    mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 312cc7b5efef14..3830d65b99e38e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -108,9 +108,12 @@ def BufferizeToAllocationOp : Op<Transform_Dialect,
     %0 = bufferization.to_tensor %alloc restrict writable : memref<10xf32>
     ```
 
-    Selected ops that bufferize to an allocation are also supported:
+    Selected ops that bufferize to an allocation (or need special handling) are
+    also supported:
     - `tensor.pad` is lowered to an allocation, followed by a `linalg.fill` and
      a buffer copy (all on memrefs).
+    - `vector.mask` is bufferized together with its region. The allocation is
+      placed in front of the `vector.mask` op.
 
     An optional memory space attribute can be specified for the materialized
     buffer allocation.

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 64faf0df59d1ef..968488cead1720 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -60,9 +60,34 @@ std::optional<vector::CombiningKind> getCombinerOpKind(Operation *combinerOp);
 /// %0 = bufferization.to_tensor %alloc restrict writable
 ///
 /// In addition to rewriting the IR as shown above, this function returns the
-/// newly allocated buffer.
+/// newly allocated buffer. The `insertionPoint` parameter can be used to
+/// specify a custom insertion point for the buffer allocation.
 Value bufferizeToAllocation(RewriterBase &rewriter, tensor::PadOp padOp,
-                            Attribute memorySpace = {});
+                            Attribute memorySpace = {},
+                            Operation *insertionPoint = nullptr);
+
+/// Materialize a buffer allocation for the given vector.mask op and bufferize
+/// the op, including its region. E.g.:
+///
+/// %0 = vector.mask {
+///   vector.transfer_write %v, %t : vector<16xf32>, tensor<?xf32>
+/// } : vector<16xi1> -> tensor<?xf32>
+///
+/// is lowered to:
+///
+/// %alloc = memref.alloc
+/// memref.tensor_store %t, %alloc
+/// vector.mask {
+///   vector.transfer_write %v, %alloc : vector<16xf32>, memref<?xf32>
+/// } : vector<16xi1>
+/// %0 = bufferization.to_tensor %alloc restrict writable
+///
+/// In addition to rewriting the IR as shown above, this function returns the
+/// newly allocated buffer. The `insertionPoint` parameter can be used to
+/// specify a custom insertion point for the buffer allocation.
+Value bufferizeToAllocation(RewriterBase &rewriter, vector::MaskOp maskOp,
+                            Attribute memorySpace = {},
+                            Operation *insertionPoint = nullptr);
 
 /// Bufferize the given op with tensor semantics and materialize the result in
 /// a newly allocated buffer.
@@ -72,10 +97,17 @@ Value bufferizeToAllocation(RewriterBase &rewriter, tensor::PadOp padOp,
 /// supported. They are bufferized using their BufferizableOpInterface
 /// implementation.
 ///
-/// Selected ops that bufferize to an allocation are also supported:
+/// Selected ops that bufferize to an allocation (or need special handling) are
+/// also supported:
 /// - tensor.pad
+/// - vector.mask
+///
+/// This function returns the newly allocated buffer. The `insertionPoint`
+/// parameter can be used to specify a custom insertion point for the buffer
+/// allocation.
 Value bufferizeToAllocation(RewriterBase &rewriter, Operation *op,
-                            Attribute memorySpace = {});
+                            Attribute memorySpace = {},
+                            Operation *insertionPoint = nullptr);
 
 /// Try to eliminate tensor::EmptyOps inside `op` that are anchored on a
 /// LinalgOp. This transforms looks for LinalgOps that have an unused output

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index b2d7fe2f58b180..570f6d255f097d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -170,15 +170,16 @@ static Value createAllocationForTensor(RewriterBase &rewriter, Location loc,
 }
 
 Value linalg::bufferizeToAllocation(RewriterBase &rewriter, PadOp padOp,
-                                    Attribute memorySpace) {
+                                    Attribute memorySpace,
+                                    Operation *insertionPoint) {
   OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(padOp);
+  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : padOp);
   Location loc = padOp.getLoc();
 
   // Create buffer allocation.
   Value alloc =
       createAllocationForTensor(rewriter, loc, padOp.getResult(), memorySpace);
-  rewriter.setInsertionPointAfter(alloc.getDefiningOp());
+  rewriter.setInsertionPoint(padOp);
 
   // Create linalg.fill or linalg.generic.
   Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, alloc);
@@ -201,6 +202,66 @@ Value linalg::bufferizeToAllocation(RewriterBase &rewriter, PadOp padOp,
   return alloc;
 }
 
+Value linalg::bufferizeToAllocation(RewriterBase &rewriter,
+                                    vector::MaskOp maskOp,
+                                    Attribute memorySpace,
+                                    Operation *insertionPoint) {
+  assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 &&
+         "expected single masked op");
+  OpBuilder::InsertionGuard g(rewriter);
+  bufferization::BufferizationOptions options;
+  Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator();
+  assert(isa<vector::YieldOp>(yieldOp) && "expected yield op terminator");
+
+  // Bufferize maskable op. By default, place the buffer allocation right before
+  // the mask op.
+  Value alloc = bufferizeToAllocation(
+      rewriter, maskOp.getMaskableOp(), memorySpace,
+      /*insertionPoint=*/insertionPoint ? insertionPoint : maskOp);
+
+  // Bufferize terminator.
+  rewriter.setInsertionPoint(yieldOp);
+  if (failed(cast<bufferization::BufferizableOpInterface>(yieldOp).bufferize(
+          rewriter, options)))
+    return nullptr;
+
+  // Erase dead to_tensor ops inside of the mask op. This is necessary because
+  // there may only be one op (apart from the terminator) inside the mask op.
+  // TODO: Remove dead to_tensor ops more aggressively during bufferization.
+  SmallVector<Operation *> toTensorOps;
+  maskOp.walk([&](bufferization::ToTensorOp toTensorOp) {
+    if (toTensorOp->getUses().empty())
+      toTensorOps.push_back(toTensorOp.getOperation());
+  });
+  for (Operation *op : toTensorOps)
+    rewriter.eraseOp(op);
+
+  // Bufferize mask op.
+  SmallVector<OpOperand *> resultUses;
+  for (Value result : maskOp.getResults())
+    if (isa<TensorType>(result.getType()))
+      for (OpOperand &use : result.getUses())
+        resultUses.push_back(&use);
+  rewriter.setInsertionPoint(maskOp);
+  if (failed(cast<bufferization::BufferizableOpInterface>(maskOp.getOperation())
+                 .bufferize(rewriter, options)))
+    return nullptr;
+
+  // Set "restrict" attribute, indicating that no other tensor aliases with
+  // this tensor. That is because we just allocated a new buffer for the tensor.
+  for (OpOperand *resultUse : resultUses) {
+    auto toTensorOp =
+        resultUse->get().getDefiningOp<bufferization::ToTensorOp>();
+    assert(toTensorOp && "expected to_tensor op");
+    rewriter.updateRootInPlace(toTensorOp, [&]() {
+      toTensorOp.setRestrict(true);
+      toTensorOp.setWritable(true);
+    });
+  }
+
+  return alloc;
+}
+
 /// Lower tensor.from_elements to a sequence of chained tensor.insert.
 FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle(
     RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
@@ -329,12 +390,15 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
 }
 
 Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Operation *op,
-                                    Attribute memorySpace) {
+                                    Attribute memorySpace,
+                                    Operation *insertionPoint) {
   using namespace bufferization;
 
   // Call specialized overload for certain ops.
   if (auto padOp = dyn_cast<tensor::PadOp>(op))
-    return bufferizeToAllocation(rewriter, padOp, memorySpace);
+    return bufferizeToAllocation(rewriter, padOp, memorySpace, insertionPoint);
+  if (auto maskOp = dyn_cast<vector::MaskOp>(op))
+    return bufferizeToAllocation(rewriter, maskOp, memorySpace, insertionPoint);
 
   // Only bufferizable ops are supported.
   auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
@@ -386,7 +450,7 @@ Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Operation *op,
 
   // Allocate buffers.
   OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(op);
+  rewriter.setInsertionPoint(insertionPoint ? insertionPoint : op);
   SmallVector<Value> allocs;
   for (OpOperand *operand : outOfPlaceOperands) {
     Value alloc = createAllocationForTensor(rewriter, op->getLoc(),
@@ -401,6 +465,7 @@ Value linalg::bufferizeToAllocation(RewriterBase &rewriter, Operation *op,
   }
 
   // Bufferize the op.
+  rewriter.setInsertionPoint(op);
   if (failed(bufferizableOp.bufferize(rewriter, options)))
     return nullptr;
 

diff  --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
index 99393ff75ff0c9..9f98d6728ed36c 100644
--- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -100,3 +100,24 @@ transform.sequence failures(propagate) {
   // expected-error @below{{failed to bufferize operation}}
   %2 = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
 }
+
+// -----
+
+// CHECK-LABEL: func @vector_mask(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>,
+//       CHECK:   %[[alloc:.*]] = memref.alloc(%{{.*}}) : memref<?xf32, 4>
+//       CHECK:   memref.tensor_store %[[t]], %[[alloc]]
+//       CHECK:   vector.mask %{{.*}} { vector.transfer_write %{{.*}}, %[[alloc]]
+//       CHECK:   %[[r:.*]] = bufferization.to_tensor %[[alloc]] restrict writable
+//       CHECK:   memref.dealloc %[[alloc]]
+//       CHECK:   return %[[r]]
+func.func @vector_mask(%t: tensor<?xf32>, %val: vector<16xf32>, %idx: index, %m0: vector<16xi1>) -> tensor<?xf32> {
+  %r = vector.mask %m0 { vector.transfer_write %val, %t[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
+  return %r : tensor<?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["vector.mask"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %2 = transform.structured.bufferize_to_allocation %0 {memory_space = 4} : !transform.any_op
+}
