[Mlir-commits] [mlir] [mlir][bufferization] Allow mixed static/dynamic shapes in `materialize_in_destination` op (PR #92681)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun May 19 02:09:52 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-bufferization
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This commit relaxes the verifier of `bufferization.materialize_in_destination` such that mixed static/dynamic dimensions are allowed for the source and destination operands. E.g., `tensor<5xf32>` and `tensor<?xf32>` are now compatible, but it is assumed that the dynamic dimension is `5` at runtime.
This commit fixes #91265.
---
Full diff: https://github.com/llvm/llvm-project/pull/92681.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+4-5)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+18)
- (modified) mlir/test/Dialect/Bufferization/invalid.mlir (+10-3)
- (modified) mlir/test/Dialect/Bufferization/ops.mlir (+5-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 4f609ddff9a41..1c70a4b8df925 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -217,8 +217,7 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
def Bufferization_MaterializeInDestinationOp
: Bufferization_Op<"materialize_in_destination",
- [AllShapesMatch<["source", "dest"]>,
- AllElementTypesMatch<["source", "dest"]>,
+ [AllElementTypesMatch<["source", "dest"]>,
BufferizableOpInterface, DestinationStyleOpInterface,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<SubsetOpInterface,
@@ -239,9 +238,9 @@ def Bufferization_MaterializeInDestinationOp
memref, `source` materializes in `dest`, which is already a buffer. The op
has no results in that case.
- `source`, `dest` and `result` (if present) must have the same shape and
- element type. If the op has a result, the types of `result` and `dest` must
- match exactly (e.g., including any tensor encodings).
+ `source`, `dest` and `result` (if present) must have the same runtime shape
+ and element type. If the op has a result, the types of `result` and `dest`
+ must match exactly (e.g., including any tensor encodings).
By default, this op bufferizes to a memcpy from the future buffer of the
`source` tensor to the future buffer of the `dest` tensor or to the `dest`
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 0acb0c24ab313..3b7b412842bfb 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -686,6 +686,24 @@ LogicalResult MaterializeInDestinationOp::verify() {
if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
return emitOpError("'writable' must be specified if and only if the "
"destination is of memref type");
+ TensorType srcType = getSource().getType();
+ ShapedType destType = cast<ShapedType>(getDest().getType());
+ if (srcType.hasRank() != destType.hasRank())
+ return emitOpError("source/destination shapes are incompatible");
+ if (srcType.hasRank()) {
+ if (srcType.getRank() != destType.getRank())
+ return emitOpError("rank mismatch between source and destination shape");
+ for (auto [src, dest] :
+ llvm::zip(srcType.getShape(), destType.getShape())) {
+ if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
+ // Cannot verify dynamic dimension size. Assume that they match at
+ // runtime.
+ continue;
+ }
+ if (src != dest)
+ return emitOpError("source/destination shapes are incompatible");
+ }
+ }
return success();
}
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 4ebdb0a8f0490..2c8807b66de74 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -43,9 +43,16 @@ func.func @invalid_writable_on_op() {
// -----
-func.func @invalid_materialize_in_destination(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
- // expected-error @below{{failed to verify that all of {source, dest} have same shape}}
- bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
+func.func @invalid_materialize_in_destination(%arg0: tensor<4xf32>, %arg1: tensor<5xf32>) {
+ // expected-error @below{{source/destination shapes are incompatible}}
+ bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<4xf32>, tensor<5xf32>) -> tensor<5xf32>
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination(%arg0: tensor<5x5xf32>, %arg1: tensor<5xf32>) {
+ // expected-error @below{{rank mismatch between source and destination shape}}
+ bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<5x5xf32>, tensor<5xf32>) -> tensor<5xf32>
}
// -----
diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
index d4bda0632189d..ad4a66c1b7978 100644
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -59,12 +59,15 @@ func.func @test_dealloc_tensor_op(%arg0: tensor<4xi32>) {
}
// CHECK-LABEL: func @test_materialize_in_destination_op
-func.func @test_materialize_in_destination_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: memref<?xf32, 3>)
- -> tensor<?xf32> {
+func.func @test_materialize_in_destination_op(
+ %arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: memref<?xf32, 3>,
+ %arg4: tensor<5xf32>) -> tensor<?xf32> {
// CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%1 = bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, memref<?xf32, 3>) -> ()
bufferization.materialize_in_destination %arg0 in restrict writable %arg2 : (tensor<?xf32>, memref<?xf32, 3>) -> ()
+ // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
+ %2 = bufferization.materialize_in_destination %arg0 in %arg4 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
return %1 : tensor<?xf32>
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/92681
More information about the Mlir-commits mailing list