[Mlir-commits] [mlir] [mlir][bufferization] Transfer `restrict` during empty tensor elimination (PR #68729)

llvmlistbot at llvm.org
Tue Oct 10 11:14:19 PDT 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir-bufferization

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Empty tensor elimination looks for `bufferization.materialize_in_destination` ops with a `tensor.empty` source. It replaces the `tensor.empty` with a `bufferization.to_tensor restrict` of the memref destination. As part of this rewrite, the `restrict` keyword is now removed from the `materialize_in_destination` op, so that no second `to_tensor restrict` op on the same buffer can be inserted later; such IR would be invalid. `bufferization.materialize_in_destination` ops that have a memref destination but no `restrict` attribute are ignored by empty tensor elimination.
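The `restrict` transfer can be illustrated with a toy C++ model. This is a minimal sketch under stated assumptions, not MLIR code: `ToyMaterializeOp`, `ToyToTensorOp`, `toyBuildSubsetExtraction` and `countRestrictToTensor` are hypothetical names, buffers are identified by their SSA name only (aliasing is not modeled), and the model assumes elimination calls the subset-extraction hook once per `tensor.empty` that reaches the op. It mirrors the early `return {}` and the `setRestrict(false)` that this patch adds to `MaterializeInDestinationOp::buildSubsetExtraction`.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy stand-in for bufferization.materialize_in_destination with a
// memref destination. Not an MLIR class.
struct ToyMaterializeOp {
  std::string dest;  // SSA name of the memref destination, e.g. "%m"
  bool hasRestrict;  // models the `restrict` unit attribute
  bool hasWritable;  // models the `writable` unit attribute
};

// Hypothetical toy stand-in for a newly created bufferization.to_tensor op.
struct ToyToTensorOp {
  std::string buffer;
  bool hasRestrict;
  bool hasWritable;
};

// Models the patched buildSubsetExtraction: without `restrict` there is no
// guarantee that no other `to_tensor restrict` aliases `dest`, so no
// replacement is built (the real code returns an empty Value). Otherwise
// `restrict` is moved from the op to the newly created to_tensor op.
bool toyBuildSubsetExtraction(ToyMaterializeOp &op,
                              std::vector<ToyToTensorOp> &created) {
  if (!op.hasRestrict)
    return false;
  op.hasRestrict = false;  // corresponds to setRestrict(false)
  created.push_back(ToyToTensorOp{op.dest, /*hasRestrict=*/true, op.hasWritable});
  return true;
}

// Counts the `to_tensor restrict` ops on `buffer`; valid IR has at most one.
std::size_t countRestrictToTensor(const std::vector<ToyToTensorOp> &ops,
                                  const std::string &buffer) {
  std::size_t n = 0;
  for (const ToyToTensorOp &t : ops)
    if (t.hasRestrict && t.buffer == buffer)
      ++n;
  return n;
}
```

With two `tensor.empty` ops flowing into one `materialize_in_destination ... in restrict writable %m` (the shape of the new `@multiple_materialize_in_destination_buffer` test), the first call returns `true` and creates one `to_tensor restrict`; the second returns `false` because `restrict` has already been moved, so the other `tensor.empty` stays in place.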

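Also relax the verifier of `materialize_in_destination`: `restrict` is now optional for memref destinations (it is still rejected for tensor destinations). The `restrict` keyword is not needed in general because, unlike `to_tensor`, this op does not expose the buffer as a tensor.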
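The verifier change can be compared with a similarly hypothetical model. `ToyMaterializeAttrs`, `verifyBefore` and `verifyAfter` are illustrative names, not MLIR API. Only the `restrict`/`writable` checks from `MaterializeInDestinationOp::verify()` are modeled (the result-count checks are omitted), and the returned strings are the messages passed to `emitOpError`, without the op-name prefix that MLIR adds.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical summary of the properties the verifier inspects.
struct ToyMaterializeAttrs {
  bool destIsMemref;  // is `dest` of memref type?
  bool hasRestrict;   // is the `restrict` attribute set?
  bool hasWritable;   // is the `writable` attribute set?
};

// Rules before the patch: `restrict` and `writable` iff memref destination.
std::optional<std::string> verifyBefore(const ToyMaterializeAttrs &a) {
  if (a.hasRestrict != a.destIsMemref)
    return "'restrict' must be specified if and only if the destination is of memref type";
  if (a.hasWritable != a.destIsMemref)
    return "'writable' must be specified if and only if the destination is of memref type";
  return std::nullopt;  // verifies
}

// Rules after the patch: `restrict` is optional for memrefs, invalid for tensors.
std::optional<std::string> verifyAfter(const ToyMaterializeAttrs &a) {
  if (a.hasRestrict && !a.destIsMemref)
    return "'restrict' is valid only for memref destinations";
  if (a.hasWritable != a.destIsMemref)
    return "'writable' must be specified if and only if the destination is of memref type";
  return std::nullopt;  // verifies
}
```

A memref destination with `writable` but without `restrict` (the case covered by the removed `@invalid_materialize_in_destination_restrict_missing` test) fails `verifyBefore` but passes `verifyAfter`, while `restrict` on a tensor destination is still rejected.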

---
Full diff: https://github.com/llvm/llvm-project/pull/68729.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+22-18) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+11-4) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir (+24) 
- (modified) mlir/test/Dialect/Bufferization/invalid.mlir (+1-8) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 34a6f5d74b13956..c779d1f843d76a0 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -246,28 +246,32 @@ def Bufferization_MaterializeInDestinationOp
     rewrite IR such that a computation is performed directly in `dest` and no
     memcpy is needed.
 
-    If `dest` is a buffer, the `restrict` and `writable` attributes must be
-    specified. These attributes have the same meaning as the respective
-    attributes of `bufferization.to_tensor`. `writable` indicates that the
-    `dest` buffer is considered writable. It does not make sense to materialize
-    a computation in a read-only buffer, so `writable` is required. `restrict`
-    indicates that this op is the only way for the tensor IR to access `dest`
-    (or an alias thereof). E.g., there must be no other `to_tensor` ops with
-    `dest` or with an alias of `dest`. Such IR is not supported by
-    One-Shot Bufferize. Ops that have incorrect usage of `restrict` may
-    bufferize incorrectly.
-
-    Note: `restrict` and `writable` could be removed from this op because they
-    must always be set for memref destinations. This op has these attributes to
-    make clear the requirements on the `dest` operand in the op assembly format.
-    Moreover, these requirements may be relaxed at some point in the future.
+    If `dest` is a buffer, the `writable` attribute must be specified and the
+    `restrict` keyword can be specified. These attributes have the same meaning
+    as the respective attributes of `bufferization.to_tensor`.
+
+    `writable` indicates that the `dest` buffer is considered writable. It does
+    not make sense to materialize a computation in a read-only buffer, so
+    `writable` is required.
+
+    `restrict` indicates that there is no `bufferization.to_tensor` op and no
+    other `bufferization.materialize_in_destination` op with `dest` (or an alias
+    thereof) and "restrict". Only ops with this attribute are considered for
+    "empty tensor elimination". As part of empty tensor elimination, a new
+    `to_tensor` op with `dest` may be inserted and the `restrict` attribute is
+    transferred from this op to the new `to_tensor` op. Having "restrict" on
+    this op guarantees that performing empty tensor elimination would not create
+    invalid IR (i.e., having multiple `to_tensor restrict` with aliasing
+    buffers).
+
+    Note: `writable` could be removed from this op because it must always be set
+    for memref destinations. This op has that attribute to make clear the
+    requirements on the `dest` operand in the op assembly format.
 
     Note: If `dest` is a tensor, `tensor.insert_slice` could be used for the
     same purpose, but since tensor dialect ops only indicate *what* should be
     computed but not *where*, it could fold away, causing the computation to
-    materialize in a different buffer. It is also possible that the
-    `tensor.insert_slice` destination bufferizes out-of-place, which would also
-    cause the computation to materialize in a buffer different buffer.
+    materialize in a different buffer.
   }];
 
   let arguments = (ins AnyTensor:$source, AnyShaped:$dest,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 738c8374d7add03..5716dcc9d905016 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -613,11 +613,19 @@ Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
     return getDest();
   }
 
+  // The "restrict" attribute is transferred from this op to the newly created
+  // to_tensor op. If this op does not have the "restrict" attribute, the subset
+  // extraction cannot be built because there is no guarantee that there is no
+  // pre-existing "restrict" to_tensor op with the same/an aliasing destination.
+  if (!getRestrict())
+    return {};
+
   // Build a bufferization.to_tensor op.
   assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
   assert(getRestrict() &&
          "expected that ops with memrefs dest have 'restrict'");
-  return builder.create<ToTensorOp>(loc, getDest(), getRestrict(),
+  setRestrict(false);
+  return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
                                     getWritable());
 }
 
@@ -647,9 +655,8 @@ LogicalResult MaterializeInDestinationOp::verify() {
   if (isa<BaseMemRefType>(getDest().getType()) &&
       getOperation()->getNumResults() != 0)
     return emitOpError("memref 'dest' implies zero results");
-  if (getRestrict() != isa<BaseMemRefType>(getDest().getType()))
-    return emitOpError("'restrict' must be specified if and only if the "
-                       "destination is of memref type");
+  if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
+    return emitOpError("'restrict' is valid only for memref destinations");
   if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
     return emitOpError("'writable' must be specified if and only if the "
                        "destination is of memref type");
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 99b974b9ef3c67e..9a3e14b6d391782 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -one-shot-bufferize="bufferize-function-boundaries" -cse -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -eliminate-empty-tensors | FileCheck %s --check-prefix=CHECK-ELIM
 
 //      CHECK: func @buffer_forwarding_conflict(
 // CHECK-SAME:   %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>
@@ -341,3 +342,26 @@ func.func @linalg_copy_empty() -> tensor<26xi32> {
   %1 = linalg.copy ins(%0 : tensor<26xi32>) outs(%0 : tensor<26xi32>) -> tensor<26xi32>
   return %1 : tensor<26xi32>
 }
+
+// -----
+
+// CHECK-ELIM-LABEL: func @multiple_materialize_in_destination_buffer(
+//  CHECK-ELIM-SAME:     %[[m:.*]]: memref<5xf32>
+//       CHECK-ELIM:   tensor.empty
+//       CHECK-ELIM:   bufferization.to_tensor %[[m]] restrict writable
+//       CHECK-ELIM:   bufferization.materialize_in_destination {{.*}} in writable %[[m]]
+func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32, %f2: f32, %c: i1) {
+  %0 = tensor.empty() : tensor<5xf32>
+  %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+
+  %1 = tensor.empty() : tensor<5xf32>
+  %filled2 = linalg.fill ins(%f2 : f32) outs(%1 : tensor<5xf32>) -> tensor<5xf32>
+
+  %selected = scf.if %c -> tensor<5xf32> {
+    scf.yield %filled : tensor<5xf32>
+  } else {
+    scf.yield %filled2 : tensor<5xf32>
+  }
+  bufferization.materialize_in_destination %selected in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
+  return
+}
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index ce56f89c1f1bbe6..996d8430b84d48b 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -80,13 +80,6 @@ func.func @invalid_materialize_in_destination_dest_type(%arg0: tensor<5xf32>, %a
 
 // -----
 
-func.func @invalid_materialize_in_destination_restrict_missing(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
-  // expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
-  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, memref<?xf32>) -> ()
-}
-
-// -----
-
 func.func @invalid_materialize_in_destination_result(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
   // expected-error @below{{memref 'dest' implies zero results}}
   bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, memref<?xf32>) -> (tensor<?xf32>)
@@ -102,7 +95,7 @@ func.func @invalid_materialize_in_destination_result_missing(%arg0: tensor<?xf32
 // -----
 
 func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
-  // expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
+  // expected-error @below{{'restrict' is valid only for memref destinations}}
   bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/68729

