[Mlir-commits] [mlir] [mlir][bufferization] Transfer `restrict` during empty tensor elimination (PR #68729)
llvmlistbot at llvm.org
Tue Oct 10 11:14:19 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir-bufferization
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Empty tensor elimination looks for `bufferization.materialize_in_destination` ops with a `tensor.empty` source. It replaces the `tensor.empty` with a `bufferization.to_tensor restrict` of the memref destination. As part of this rewrite, the `restrict` keyword should be removed from the `materialize_in_destination` op so that no second `to_tensor restrict` op is inserted for the same destination; such IR would be invalid. `bufferization.materialize_in_destination` ops with a memref destination and without the `restrict` attribute are ignored by empty tensor elimination.
Also relax the verifier of `materialize_in_destination`. The `restrict` keyword is not generally needed because the op does not expose the buffer as a tensor.
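The relaxed verifier rule can be modeled outside MLIR as a small truth table. The following is a minimal standalone C++ sketch, not the code from this PR: `MaterializeModel` and `verifyModel` are hypothetical names introduced for illustration, and only the two error strings are taken from the diff.

```cpp
#include <optional>
#include <string>

// Hypothetical stand-in for the properties of a
// bufferization.materialize_in_destination op that the verifier inspects.
// This is an illustration only, not an MLIR type.
struct MaterializeModel {
  bool destIsMemRef; // true: memref destination, false: tensor destination
  bool hasRestrict;  // the `restrict` keyword is present
  bool hasWritable;  // the `writable` keyword is present
};

// Returns an error message if the modeled op violates the relaxed rules,
// or std::nullopt if it passes.
std::optional<std::string> verifyModel(const MaterializeModel &op) {
  // `restrict` is optional for memref destinations and rejected for tensors.
  if (op.hasRestrict && !op.destIsMemRef)
    return "'restrict' is valid only for memref destinations";
  // `writable` is still required exactly when the destination is a memref.
  if (op.hasWritable != op.destIsMemRef)
    return "'writable' must be specified if and only if the destination is "
           "of memref type";
  return std::nullopt;
}
```

Under this model, a memref destination with `writable` and without `restrict` passes verification. According to the updated op documentation, such an op is simply not considered for empty tensor elimination.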
---
Full diff: https://github.com/llvm/llvm-project/pull/68729.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+22-18)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+11-4)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir (+24)
- (modified) mlir/test/Dialect/Bufferization/invalid.mlir (+1-8)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 34a6f5d74b13956..c779d1f843d76a0 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -246,28 +246,32 @@ def Bufferization_MaterializeInDestinationOp
rewrite IR such that a computation is performed directly in `dest` and no
memcpy is needed.
- If `dest` is a buffer, the `restrict` and `writable` attributes must be
- specified. These attributes have the same meaning as the respective
- attributes of `bufferization.to_tensor`. `writable` indicates that the
- `dest` buffer is considered writable. It does not make sense to materialize
- a computation in a read-only buffer, so `writable` is required. `restrict`
- indicates that this op is the only way for the tensor IR to access `dest`
- (or an alias thereof). E.g., there must be no other `to_tensor` ops with
- `dest` or with an alias of `dest`. Such IR is not supported by
- One-Shot Bufferize. Ops that have incorrect usage of `restrict` may
- bufferize incorrectly.
-
- Note: `restrict` and `writable` could be removed from this op because they
- must always be set for memref destinations. This op has these attributes to
- make clear the requirements on the `dest` operand in the op assembly format.
- Moreover, these requirements may be relaxed at some point in the future.
+ If `dest` is a buffer, the `writable` attribute must be specified and the
+ `restrict` keyword can be specified. These attributes have the same meaning
+ as the respective attributes of `bufferization.to_tensor`.
+
+ `writable` indicates that the `dest` buffer is considered writable. It does
+ not make sense to materialize a computation in a read-only buffer, so
+ `writable` is required.
+
+ `restrict` indicates that there is no `bufferization.to_tensor` op and no
+ other `bufferization.materialize_in_destination` op with `dest` (or an alias
+ thereof) and "restrict". Only ops with this attribute are considered for
+ "empty tensor elimination". As part of empty tensor elimination, a new
+ `to_tensor` op with `dest` may be inserted and the `restrict` attribute is
+ transferred from this op to the new `to_tensor` op. Having "restrict" on
+ this op guarantees that performing empty tensor elimination would not create
+ invalid IR (i.e., having multiple `to_tensor restrict` with aliasing
+ buffers).
+
+ Note: `writable` could be removed from this op because it must always be set
+ for memref destinations. This op has that attribute to make clear the
+ requirements on the `dest` operand in the op assembly format.
Note: If `dest` is a tensor, `tensor.insert_slice` could be used for the
same purpose, but since tensor dialect ops only indicate *what* should be
computed but not *where*, it could fold away, causing the computation to
- materialize in a different buffer. It is also possible that the
- `tensor.insert_slice` destination bufferizes out-of-place, which would also
- cause the computation to materialize in a buffer different buffer.
+ materialize in a different buffer.
}];
let arguments = (ins AnyTensor:$source, AnyShaped:$dest,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 738c8374d7add03..5716dcc9d905016 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -613,11 +613,19 @@ Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
return getDest();
}
+ // The "restrict" attribute is transferred from this op to the newly created
+ // to_tensor op. If this op does not the "restrict" attribute, the subset
+ // extraction cannot be built because there is no guarantee that there is no
+ // pre-existing "restrict" to_tensor op with the same/an aliasing destination.
+ if (!getRestrict())
+ return {};
+
// Build a bufferization.to_tensor op.
assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
assert(getRestrict() &&
"expected that ops with memrefs dest have 'restrict'");
- return builder.create<ToTensorOp>(loc, getDest(), getRestrict(),
+ setRestrict(false);
+ return builder.create<ToTensorOp>(loc, getDest(), /*restrict=*/true,
getWritable());
}
@@ -647,9 +655,8 @@ LogicalResult MaterializeInDestinationOp::verify() {
if (isa<BaseMemRefType>(getDest().getType()) &&
getOperation()->getNumResults() != 0)
return emitOpError("memref 'dest' implies zero results");
- if (getRestrict() != isa<BaseMemRefType>(getDest().getType()))
- return emitOpError("'restrict' must be specified if and only if the "
- "destination is of memref type");
+ if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
+ return emitOpError("'restrict' is valid only for memref destinations");
if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
return emitOpError("'writable' must be specified if and only if the "
"destination is of memref type");
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 99b974b9ef3c67e..9a3e14b6d391782 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -eliminate-empty-tensors -empty-tensor-to-alloc-tensor -one-shot-bufferize="bufferize-function-boundaries" -cse -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -eliminate-empty-tensors | FileCheck %s --check-prefix=CHECK-ELIM
// CHECK: func @buffer_forwarding_conflict(
// CHECK-SAME: %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>
@@ -341,3 +342,26 @@ func.func @linalg_copy_empty() -> tensor<26xi32> {
%1 = linalg.copy ins(%0 : tensor<26xi32>) outs(%0 : tensor<26xi32>) -> tensor<26xi32>
return %1 : tensor<26xi32>
}
+
+// -----
+
+// CHECK-ELIM-LABEL: func @multiple_materialize_in_destination_buffer(
+// CHECK-ELIM-SAME: %[[m:.*]]: memref<5xf32>
+// CHECK-ELIM: tensor.empty
+// CHECK-ELIM: bufferization.to_tensor %[[m]] restrict writable
+// CHECK-ELIM: bufferization.materialize_in_destination {{.*}} in writable %[[m]]
+func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32, %f2: f32, %c: i1) {
+ %0 = tensor.empty() : tensor<5xf32>
+ %filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+
+ %1 = tensor.empty() : tensor<5xf32>
+ %filled2 = linalg.fill ins(%f2 : f32) outs(%1 : tensor<5xf32>) -> tensor<5xf32>
+
+ %selected = scf.if %c -> tensor<5xf32> {
+ scf.yield %filled : tensor<5xf32>
+ } else {
+ scf.yield %filled2 : tensor<5xf32>
+ }
+ bufferization.materialize_in_destination %selected in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
+ return
+}
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index ce56f89c1f1bbe6..996d8430b84d48b 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -80,13 +80,6 @@ func.func @invalid_materialize_in_destination_dest_type(%arg0: tensor<5xf32>, %a
// -----
-func.func @invalid_materialize_in_destination_restrict_missing(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
- // expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
- bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, memref<?xf32>) -> ()
-}
-
-// -----
-
func.func @invalid_materialize_in_destination_result(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
// expected-error @below{{memref 'dest' implies zero results}}
bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, memref<?xf32>) -> (tensor<?xf32>)
@@ -102,7 +95,7 @@ func.func @invalid_materialize_in_destination_result_missing(%arg0: tensor<?xf32
// -----
func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
- // expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
+ // expected-error @below{{'restrict' is valid only for memref destinations}}
bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/68729