[Mlir-commits] [mlir] [MLIR][Bufferization] Address tensor cast canonicalizer interaction with `bufferization.materialize_in_destination` (PR #91274)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 6 14:28:02 PDT 2024
https://github.com/christopherbate created https://github.com/llvm/llvm-project/pull/91274
Attempts to address a bug reported in https://github.com/llvm/llvm-project/issues/91265
by relaxing the requirement that the source and dest shapes of the
`bufferization.materialize_in_destination` operation match exactly. Ranks
must still match, and a dimension may be static in one operand and dynamic
in the other, but shapes that are statically known to differ are still rejected.
>From e597bdea6a417a2f81966eb112e8226142bd8c27 Mon Sep 17 00:00:00 2001
From: Christopher Bate <cbate at nvidia.com>
Date: Mon, 6 May 2024 14:38:56 -0600
Subject: [PATCH] [MLIR][Bufferization] Address tensor cast canonicalizer
interaction with `bufferization.materialize_in_destination`
Attempts to address a bug reported in https://github.com/llvm/llvm-project/issues/91265
by relaxing the requirement that the source and dest shapes of the
`bufferization.materialize_in_destination` operation match exactly. Ranks
must still match, and a dimension may be static in one operand and dynamic
in the other, but shapes that are statically known to differ are still rejected.
---
.../Dialect/Bufferization/IR/BufferizationOps.td | 2 +-
.../Dialect/Bufferization/IR/BufferizationOps.cpp | 9 +++++++++
mlir/test/Dialect/Bufferization/canonicalize.mlir | 13 +++++++++++++
mlir/test/Dialect/Bufferization/invalid.mlir | 6 +++---
4 files changed, 26 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 4f609ddff9a413..15bfc26c17fb7a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -217,7 +217,7 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [
def Bufferization_MaterializeInDestinationOp
: Bufferization_Op<"materialize_in_destination",
- [AllShapesMatch<["source", "dest"]>,
+ [AllRanksMatch<["source", "dest"]>, // Per-dim static/dynamic compatibility is checked in verify().
AllElementTypesMatch<["source", "dest"]>,
BufferizableOpInterface, DestinationStyleOpInterface,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 0acb0c24ab313b..a6bf2a5a848777 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include <optional>
@@ -670,6 +671,14 @@ bool MaterializeInDestinationOp::operatesOnDisjointSubset(
}
LogicalResult MaterializeInDestinationOp::verify() {
+ // Static dims must match; a dynamic dim on either side is compatible.
+ for (auto [srcDim, destDim] : llvm::zip_equal(
+ getSource().getType().getShape(), getDest().getType().getShape())) {
+ if (!ShapedType::isDynamic(srcDim) &&
+ !ShapedType::isDynamic(destDim) && srcDim != destDim)
+ return emitOpError("'source' and 'dest' must have compatible shapes");
+ }
+
if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
return emitOpError("'dest' must be a tensor or a memref");
if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index b6c0a0e25efe0e..a86a6b08889a00 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -388,3 +388,16 @@ func.func @negative_input() -> tensor<?x?x?xf16> {
%11 = bufferization.alloc_tensor(%c10, %idx-3, %idx27) : tensor<?x?x?xf16>
return %11 : tensor<?x?x?xf16>
}
+
+// -----
+// A tensor.cast feeding `source` should fold away, leaving a static source type.
+func.func @materialize_in_destination_tensor_cast(%arg0: tensor<4xf32>, %arg1: index) -> tensor<?xf32> {
+ %0 = bufferization.alloc_tensor(%arg1) : tensor<?xf32>
+ %1 = tensor.cast %arg0 : tensor<4xf32> to tensor<?xf32>
+ %2 = bufferization.materialize_in_destination %1 in %0 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+ return %2 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @materialize_in_destination_tensor_cast
+// CHECK: bufferization.materialize_in_destination
+// CHECK-SAME: : (tensor<4xf32>, tensor<?xf32>) -> tensor<?xf32>
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 4ebdb0a8f0490e..a1989cbc54dc84 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -43,9 +43,9 @@ func.func @invalid_writable_on_op() {
// -----
-func.func @invalid_materialize_in_destination(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
- // expected-error @below{{failed to verify that all of {source, dest} have same shape}}
- bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
+func.func @invalid_materialize_in_destination(%arg0: tensor<6xf32>, %arg1: tensor<5xf32>) { // static dims 6 and 5 conflict
+ // expected-error @below {{'bufferization.materialize_in_destination' op 'source' and 'dest' must have compatible shapes}}
+ bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<6xf32>, tensor<5xf32>) -> tensor<5xf32>
}
// -----
More information about the Mlir-commits
mailing list