[Mlir-commits] [mlir] [mlir][bufferization] Allow mixed static/dynamic shapes in `materialize_in_destination` op (PR #92681)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun May 19 02:09:52 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-bufferization

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit relaxes the verifier of `bufferization.materialize_in_destination` such that mixed static/dynamic dimensions are allowed for the source and destination operands. E.g., `tensor<5xf32>` and `tensor<?xf32>` are now compatible, but it is assumed that the dynamic dimension is `5` at runtime.

This commit fixes #91265.

---
Full diff: https://github.com/llvm/llvm-project/pull/92681.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+4-5) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+18) 
- (modified) mlir/test/Dialect/Bufferization/invalid.mlir (+10-3) 
- (modified) mlir/test/Dialect/Bufferization/ops.mlir (+5-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 4f609ddff9a41..1c70a4b8df925 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -217,8 +217,7 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
 
 def Bufferization_MaterializeInDestinationOp
     : Bufferization_Op<"materialize_in_destination",
-        [AllShapesMatch<["source", "dest"]>,
-         AllElementTypesMatch<["source", "dest"]>,
+        [AllElementTypesMatch<["source", "dest"]>,
          BufferizableOpInterface, DestinationStyleOpInterface,
          DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
          DeclareOpInterfaceMethods<SubsetOpInterface,
@@ -239,9 +238,9 @@ def Bufferization_MaterializeInDestinationOp
     memref, `source` materializes in `dest`, which is already a buffer. The op
     has no results in that case.
 
-    `source`, `dest` and `result` (if present) must have the same shape and
-    element type. If the op has a result, the types of `result` and `dest` must
-    match exactly (e.g., including any tensor encodings).
+    `source`, `dest` and `result` (if present) must have the same runtime shape
+    and element type. If the op has a result, the types of `result` and `dest`
+    must match exactly (e.g., including any tensor encodings).
 
     By default, this op bufferizes to a memcpy from the future buffer of the
     `source` tensor to the future buffer of the `dest` tensor or to the `dest`
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 0acb0c24ab313..3b7b412842bfb 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -686,6 +686,24 @@ LogicalResult MaterializeInDestinationOp::verify() {
   if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
     return emitOpError("'writable' must be specified if and only if the "
                        "destination is of memref type");
+  TensorType srcType = getSource().getType();
+  ShapedType destType = cast<ShapedType>(getDest().getType());
+  if (srcType.hasRank() != destType.hasRank())
+    return emitOpError("source/destination shapes are incompatible");
+  if (srcType.hasRank()) {
+    if (srcType.getRank() != destType.getRank())
+      return emitOpError("rank mismatch between source and destination shape");
+    for (auto [src, dest] :
+         llvm::zip(srcType.getShape(), destType.getShape())) {
+      if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
+        // Cannot verify dynamic dimension size. Assume that they match at
+        // runtime.
+        continue;
+      }
+      if (src != dest)
+        return emitOpError("source/destination shapes are incompatible");
+    }
+  }
   return success();
 }
 
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 4ebdb0a8f0490..2c8807b66de74 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -43,9 +43,16 @@ func.func @invalid_writable_on_op() {
 
 // -----
 
-func.func @invalid_materialize_in_destination(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
-  // expected-error @below{{failed to verify that all of {source, dest} have same shape}}
-  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
+func.func @invalid_materialize_in_destination(%arg0: tensor<4xf32>, %arg1: tensor<5xf32>) {
+  // expected-error @below{{source/destination shapes are incompatible}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<4xf32>, tensor<5xf32>) -> tensor<5xf32>
+}
+
+// -----
+
+func.func @invalid_materialize_in_destination(%arg0: tensor<5x5xf32>, %arg1: tensor<5xf32>) {
+  // expected-error @below{{rank mismatch between source and destination shape}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<5x5xf32>, tensor<5xf32>) -> tensor<5xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
index d4bda0632189d..ad4a66c1b7978 100644
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -59,12 +59,15 @@ func.func @test_dealloc_tensor_op(%arg0: tensor<4xi32>) {
 }
 
 // CHECK-LABEL: func @test_materialize_in_destination_op
-func.func @test_materialize_in_destination_op(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: memref<?xf32, 3>)
-    -> tensor<?xf32> {
+func.func @test_materialize_in_destination_op(
+    %arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: memref<?xf32, 3>,
+    %arg4: tensor<5xf32>) -> tensor<?xf32> {
   // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   %1 = bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
   // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, memref<?xf32, 3>) -> ()
   bufferization.materialize_in_destination %arg0 in restrict writable %arg2 : (tensor<?xf32>, memref<?xf32, 3>) -> ()
+  // CHECK: bufferization.materialize_in_destination {{.*}} : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
+  %2 = bufferization.materialize_in_destination %arg0 in %arg4 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
   return %1 : tensor<?xf32>
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/92681


More information about the Mlir-commits mailing list