[Mlir-commits] [mlir] 8d08166 - [MLIR][Tensor] Fix source/dest type check in UnPackOp canonicalize (#106094)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 3 19:10:46 PDT 2024


Author: yifeizh2
Date: 2024-09-04T10:10:43+08:00
New Revision: 8d0816615f920b0783bafa903804b9e2a2fa4e91

URL: https://github.com/llvm/llvm-project/commit/8d0816615f920b0783bafa903804b9e2a2fa4e91
DIFF: https://github.com/llvm/llvm-project/commit/8d0816615f920b0783bafa903804b9e2a2fa4e91.diff

LOG: [MLIR][Tensor] Fix source/dest type check in UnPackOp canonicalize (#106094)

Fix `RankedTensorType` equality check in unpack op canonicalization.

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e11c6aaccf74dd..996de530c255d4 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4203,7 +4203,7 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
 }
 
 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
-  // Fold an unpack(pack(x)) to x.
+  // Fold a pack(unpack(x)) to x.
   if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
     if (unPackOp.getSourceType() != packOp.getDestType())
       return failure();
@@ -4437,9 +4437,9 @@ static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
 
 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                      PatternRewriter &rewriter) {
-  /// pack(unpack(x)) -> x
+  /// unpack(pack(x)) -> x
   if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
-    if (packOp.getDestType() != unPackOp.getSourceType())
+    if (packOp.getSourceType() != unPackOp.getDestType())
       return failure();
     if (packOp.getPaddingValue() ||
         !hasSameInnerOuterAttribute(packOp, unPackOp) ||

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 458ff51be7462e..735790e5bd6c5e 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2268,6 +2268,19 @@ func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) ->
 
 // -----
 
+// CHECK: func.func @unpack_pack_with_padding_no_canonicalization(
+// CHECK:         tensor.pack
+// CHECK:         tensor.unpack
+func.func @unpack_pack_with_padding_no_canonicalization(%t: tensor<256x512xbf16>) -> tensor<224x512xbf16> {
+  %tensor_empty = tensor.empty() : tensor<4x16x64x32xbf16>
+  %tensor_empty1 = tensor.empty() : tensor<224x512xbf16>
+  %packed = tensor.pack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty : tensor<256x512xbf16> -> tensor<4x16x64x32xbf16>
+  %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty1 : tensor<4x16x64x32xbf16> -> tensor<224x512xbf16> 
+  return %unpacked : tensor<224x512xbf16>
+}
+
+// -----
+
 // Chain NCnc -> NC -> NC -> NCnc
 // CHECK: func.func @pack_unpack(
 // CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,


        


More information about the Mlir-commits mailing list