[Mlir-commits] [mlir] [Linalg] Fix linalg.pack canonicalization priority issue (PR #160340)
Nirvedh Meshram
llvmlistbot at llvm.org
Tue Sep 23 09:56:12 PDT 2025
https://github.com/nirvedhmeshram updated https://github.com/llvm/llvm-project/pull/160340
From 5c5b64d7b880e3674a77413e9891b30f22635cb9 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Tue, 23 Sep 2025 09:26:53 -0700
Subject: [PATCH] fix linalg.pack canonicalization
Signed-off-by: Nirvedh Meshram <nirvedh at gmail.com>
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 15 +++++++--------
mlir/test/Dialect/Linalg/canonicalize.mlir | 3 ++-
2 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 578931e1351c6..4bc4d97697a21 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5583,14 +5583,13 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
// Fold an pack(unpack(x)) to x.
if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
- if (unPackOp.getSourceType() != packOp.getDestType())
- return failure();
- if (packOp.getPaddingValue() ||
- !hasSameInnerOuterAttribute(packOp, unPackOp) ||
- !haveSameTiles(packOp, unPackOp))
- return failure();
- rewriter.replaceOp(packOp, unPackOp.getSource());
- return success();
+ if (unPackOp.getSourceType() == packOp.getDestType() &&
+ !packOp.getPaddingValue() &&
+ hasSameInnerOuterAttribute(packOp, unPackOp) &&
+ haveSameTiles(packOp, unPackOp)) {
+ rewriter.replaceOp(packOp, unPackOp.getSource());
+ return success();
+ }
}
// Fold optional PaddingValue operand away if padding is not needed.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 5c5f7e861d37d..26d2d98572f47 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1756,10 +1756,11 @@ func.func @pack_unpack(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index)
// CHECK-SAME: %[[T:.+]]: tensor<16x16x8x8xf32>
// CHECK: return %[[T]] : tensor<16x16x8x8xf32>
func.func @pack_unpack(%t: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
+ %cst = arith.constant 0.000000e+00 : f32 // Padding is unneeded (128 % 8 == 0); canonicalization must still fold pack(unpack(%t)) back to %t.
%tensor_empty = tensor.empty() : tensor<128x128xf32>
%unpacked = linalg.unpack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
%tensor_empty1 = tensor.empty() : tensor<16x16x8x8xf32>
- %packed = linalg.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
+ %packed = linalg.pack %unpacked padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
return %packed : tensor<16x16x8x8xf32>
}
More information about the Mlir-commits mailing list