[Mlir-commits] [mlir] [MLIR][Tensor] Fix source/dest type check in UnPackOp canonicalize (PR #106094)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Aug 26 08:39:53 PDT 2024
https://github.com/yifeizh2 created https://github.com/llvm/llvm-project/pull/106094
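Fix the type check in the unpack(pack(x)) canonicalization of UnPackOp: the fold is only valid when the pack's source type equals the unpack's destination type, but the code compared the pack's destination type with the unpack's source type instead.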
>From a2a488c11aa5693771ac415a0eff37b9006f5628 Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei" <yifei.zhang at intel.com>
Date: Mon, 26 Aug 2024 08:15:15 -0700
Subject: [PATCH] [MLIR][Tensor] Fix source/dest type check in UnPackOp
canonicalize
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 2 +-
mlir/test/Dialect/Tensor/canonicalize.mlir | 13 +++++++++++++
2 files changed, 14 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e11c6aaccf74dd..359330ee286042 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4439,7 +4439,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
PatternRewriter &rewriter) {
/// unpack(pack(x)) -> x
if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
- if (packOp.getDestType() != unPackOp.getSourceType())
+ if (packOp.getSourceType() != unPackOp.getDestType())
return failure();
if (packOp.getPaddingValue() ||
!hasSameInnerOuterAttribute(packOp, unPackOp) ||
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 458ff51be7462e..735790e5bd6c5e 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2268,6 +2268,19 @@ func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) ->
// -----
+// CHECK: func.func @unpack_pack_with_padding_no_canonicalization(
+// CHECK: tensor.pack
+// CHECK: tensor.unpack
+func.func @unpack_pack_with_padding_no_canonicalization(%t: tensor<256x512xbf16>) -> tensor<224x512xbf16> {
+ %tensor_empty = tensor.empty() : tensor<4x16x64x32xbf16>
+ %tensor_empty1 = tensor.empty() : tensor<224x512xbf16>
+ %packed = tensor.pack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty : tensor<256x512xbf16> -> tensor<4x16x64x32xbf16>
+ %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty1 : tensor<4x16x64x32xbf16> -> tensor<224x512xbf16>
+ return %unpacked : tensor<224x512xbf16>
+}
+
+// -----
+
// Chain NCnc -> NC -> NC -> NCnc
// CHECK: func.func @pack_unpack(
// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
More information about the Mlir-commits
mailing list