[Mlir-commits] [mlir] [MLIR][Bufferization] Address tensor cast canonicalizer interaction with `bufferization.materialize_in_destination` (PR #91274)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 6 14:28:31 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Chris (christopherbate)

<details>
<summary>Changes</summary>

This PR attempts to address the bug reported in https://github.com/llvm/llvm-project/issues/91265
by relaxing the requirement that the source and dest shapes match exactly in the
`bufferization.materialize_in_destination` operation. The ranks must still
match, and a static dimension may now be paired with a dynamic one, but the
verifier still rejects cases where two corresponding dimensions are both
static and differ.


---
Full diff: https://github.com/llvm/llvm-project/pull/91274.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+1-1) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+9) 
- (modified) mlir/test/Dialect/Bufferization/canonicalize.mlir (+13) 
- (modified) mlir/test/Dialect/Bufferization/invalid.mlir (+3-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 4f609ddff9a413..15bfc26c17fb7a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -217,7 +217,7 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
 
 def Bufferization_MaterializeInDestinationOp
     : Bufferization_Op<"materialize_in_destination",
-        [AllShapesMatch<["source", "dest"]>,
+        [AllRanksMatch<["source", "dest"]>,
          AllElementTypesMatch<["source", "dest"]>,
          BufferizableOpInterface, DestinationStyleOpInterface,
          DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 0acb0c24ab313b..a6bf2a5a848777 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/Matchers.h"
 #include <optional>
 
@@ -670,6 +671,14 @@ bool MaterializeInDestinationOp::operatesOnDisjointSubset(
 }
 
 LogicalResult MaterializeInDestinationOp::verify() {
+  // The shapes of `source` and `dest` must be compatible.
+  for (auto [srcDim, destDim] : llvm::zip_equal(
+           getSource().getType().getShape(), getDest().getType().getShape())) {
+    if (!ShapedType::isDynamic(srcDim) &&
+        !ShapedType::isDynamicShape(destDim) && srcDim != destDim)
+      return emitOpError("'source' and 'dest' must have compatible shapes");
+  }
+
   if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
     return emitOpError("'dest' must be a tensor or a memref");
   if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index b6c0a0e25efe0e..a86a6b08889a00 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -388,3 +388,16 @@ func.func @negative_input() -> tensor<?x?x?xf16> {
   %11 = bufferization.alloc_tensor(%c10, %idx-3, %idx27) : tensor<?x?x?xf16>
   return %11 : tensor<?x?x?xf16>
 }
+
+// -----
+
+func.func @materialize_in_destination_tensor_cast(%arg0: tensor<4xf32>, %arg1: index) -> tensor<?xf32> {
+  %0 = bufferization.alloc_tensor(%arg1) : tensor<?xf32>
+  %1 = tensor.cast %arg0 : tensor<4xf32> to tensor<?xf32>
+  %2 = bufferization.materialize_in_destination %1 in %0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  return %2 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @materialize_in_destination_tensor_cast
+//       CHECK:   bufferization.materialize_in_destination
+//  CHECK-SAME:    : (tensor<4xf32>, tensor<?xf32>) -> tensor<?xf32>
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 4ebdb0a8f0490e..a1989cbc54dc84 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -43,9 +43,9 @@ func.func @invalid_writable_on_op() {
 
 // -----
 
-func.func @invalid_materialize_in_destination(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
-  // expected-error @below{{failed to verify that all of {source, dest} have same shape}}
-  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
+func.func @invalid_materialize_in_destination(%arg0: tensor<6xf32>, %arg1: tensor<5xf32>) {
+  // expected-error @below {{'bufferization.materialize_in_destination' op 'source' and 'dest' must have compatible shapes}}
+  bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<6xf32>, tensor<5xf32>) -> tensor<5xf32>
 }
 
 // -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/91274


More information about the Mlir-commits mailing list