[Mlir-commits] [mlir] [mlir][tensor] Fix tensor::PackOp fold() handling of padding value (PR #87296)

Han-Chung Wang llvmlistbot at llvm.org
Tue Apr 2 09:39:54 PDT 2024


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/87296

>From dc6004e270466219b42c7273fdd87b0478ea971b Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 1 Apr 2024 17:01:16 -0700
Subject: [PATCH 1/3] [mlir][tensor] Fix a bug in constant packing folding.

We can't just check whether the source is a splat constant. We should also
check whether the splat value matches the padding value.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   | 13 ++++++---
 mlir/test/Dialect/Tensor/canonicalize.mlir | 33 ++++++++++++++++++++++
 2 files changed, 42 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 38a9ad60bb7948..8dc1ef67ce65c5 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1069,9 +1069,11 @@ void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
 /// Try to remove a tensor operation if it would only reshape a constant.
 /// Removes the op and replaces the constant with a new constant of the result
 /// shape.
-static OpFoldResult reshapeConstantSource(DenseElementsAttr source,
-                                          TensorType result) {
-  if (source && source.isSplat() && result.hasStaticShape())
+static OpFoldResult
+reshapeConstantSource(DenseElementsAttr source, TensorType result,
+                      std::optional<Attribute> cst = std::nullopt) {
+  if (source && source.isSplat() && result.hasStaticShape() &&
+      (!cst.has_value() || source.getSplatValue<Attribute>() == cst.value()))
     return source.resizeSplat(result);
 
   return {};
@@ -4143,9 +4145,12 @@ bool PackOp::isLikePad() {
 }
 
 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
+  std::optional<Attribute> paddingValue;
+  if (adaptor.getPaddingValue())
+    paddingValue = adaptor.getPaddingValue();
   if (OpFoldResult reshapedSource = reshapeConstantSource(
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
-          getResult().getType()))
+          adaptor.getDestType(), paddingValue))
     return reshapedSource;
   return {};
 }
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 9ab54fe9c133db..ac365c9d297e88 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -830,6 +830,39 @@ func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x1
 
 // -----
 
+// CHECK-LABEL: func @fold_padding_value_pack_constant_splat
+//   CHECK-NOT: tensor.pack
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
+func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+  %pad = arith.constant 1.000000e-01 : f32
+  %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
+  %0 = tensor.pack %cst
+    padding_value(%pad : f32)
+    outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
+    inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
+  return %0 : tensor<8x16x8x32xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
+//       CHECK: tensor.pack
+func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+  %pad = arith.constant 0.0 : f32
+  %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
+  %0 = tensor.pack %cst
+    padding_value(%pad : f32)
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, 32]
+    into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
+  return %0 : tensor<8x16x8x32xf32>
+}
+
+// -----
+
 func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<31250x1200x16x1xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = tensor.empty() : tensor<31250x1200x16x1xf32>

>From ac9d0f77226fc2d2edfb3060a442907515f24e7d Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 1 Apr 2024 17:07:04 -0700
Subject: [PATCH 2/3] fix compile error

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8dc1ef67ce65c5..609b030b84f0bb 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4150,7 +4150,7 @@ OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
     paddingValue = adaptor.getPaddingValue();
   if (OpFoldResult reshapedSource = reshapeConstantSource(
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
-          adaptor.getDestType(), paddingValue))
+          getDestType(), paddingValue))
     return reshapedSource;
   return {};
 }

>From e1d25b0ea49b0c63561668376fd933326d63728d Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Tue, 2 Apr 2024 09:39:37 -0700
Subject: [PATCH 3/3] address comments

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 609b030b84f0bb..0ce40e81371209 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1068,7 +1068,8 @@ void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 /// Try to remove a tensor operation if it would only reshape a constant.
 /// Removes the op and replaces the constant with a new constant of the result
-/// shape.
+/// shape. When an optional cst attribute is passed, it is reshaped only if the
+/// splat value matches the value in the attribute.
 static OpFoldResult
 reshapeConstantSource(DenseElementsAttr source, TensorType result,
                       std::optional<Attribute> cst = std::nullopt) {
@@ -4146,8 +4147,8 @@ bool PackOp::isLikePad() {
 
 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
   std::optional<Attribute> paddingValue;
-  if (adaptor.getPaddingValue())
-    paddingValue = adaptor.getPaddingValue();
+  if (auto pad = adaptor.getPaddingValue())
+    paddingValue = pad;
   if (OpFoldResult reshapedSource = reshapeConstantSource(
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
           getDestType(), paddingValue))



More information about the Mlir-commits mailing list