[Mlir-commits] [mlir] [Linalg] Fix linalg.pack canonicalization priority issue (PR #160340)

Nirvedh Meshram llvmlistbot at llvm.org
Tue Sep 23 09:56:12 PDT 2025


https://github.com/nirvedhmeshram updated https://github.com/llvm/llvm-project/pull/160340

>From 5c5b64d7b880e3674a77413e9891b30f22635cb9 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Tue, 23 Sep 2025 09:26:53 -0700
Subject: [PATCH] fix linalg.pack canonicalization

Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp   | 15 +++++++--------
 mlir/test/Dialect/Linalg/canonicalize.mlir |  3 ++-
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 578931e1351c6..4bc4d97697a21 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5583,14 +5583,13 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
   // Fold an pack(unpack(x)) to x.
   if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
-    if (unPackOp.getSourceType() != packOp.getDestType())
-      return failure();
-    if (packOp.getPaddingValue() ||
-        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
-        !haveSameTiles(packOp, unPackOp))
-      return failure();
-    rewriter.replaceOp(packOp, unPackOp.getSource());
-    return success();
+    if (unPackOp.getSourceType() == packOp.getDestType() &&
+        !packOp.getPaddingValue() &&
+        hasSameInnerOuterAttribute(packOp, unPackOp) &&
+        haveSameTiles(packOp, unPackOp)) {
+      rewriter.replaceOp(packOp, unPackOp.getSource());
+      return success();
+    }
   }
 
   // Fold optional PaddingValue operand away if padding is not needed.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 5c5f7e861d37d..26d2d98572f47 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1756,10 +1756,11 @@ func.func @pack_unpack(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index)
 // CHECK-SAME: %[[T:.+]]: tensor<16x16x8x8xf32>
 // CHECK: return %[[T]] : tensor<16x16x8x8xf32>
 func.func @pack_unpack(%t: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
   %tensor_empty = tensor.empty() : tensor<128x128xf32>
   %unpacked = linalg.unpack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
   %tensor_empty1 = tensor.empty() : tensor<16x16x8x8xf32>
-  %packed = linalg.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+  %packed = linalg.pack %unpacked padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
   return %packed : tensor<16x16x8x8xf32>
 }
 



More information about the Mlir-commits mailing list