[Mlir-commits] [mlir] 171cac9 - [mlir][tensor] Fold padding_value away for pack ops when possible. (#74005)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 1 11:13:10 PST 2023


Author: Han-Chung Wang
Date: 2023-12-01T11:12:58-08:00
New Revision: 171cac95a7eb1526f4d18bf8f654275656183ce4

URL: https://github.com/llvm/llvm-project/commit/171cac95a7eb1526f4d18bf8f654275656183ce4
DIFF: https://github.com/llvm/llvm-project/commit/171cac95a7eb1526f4d18bf8f654275656183ce4.diff

LOG: [mlir][tensor] Fold padding_value away for pack ops when possible. (#74005)

If we can infer statically that there are no incomplete tiles, we can
remove the optional padding operand.

Fixes https://github.com/openxla/iree/issues/15417

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 5bfcb35127b5267..cd9b82d2c553fae 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
@@ -3800,17 +3801,68 @@ static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
   return true;
 }
 
-/// Fold an unpack(pack(x)) to x.
+/// Returns true if the pack op provably does not need a padding value.
+///
+/// This requires every tiled source dimension and every inner tile size to be
+/// static, each tiled dimension to be an exact multiple of its tile size, and
+/// the destination to allocate no outer tiles beyond those holding source
+/// data. The last check matters because the verifier only requires the
+/// destination to be at least as large as the inferred packed shape; any
+/// extra outer tiles are filled with the padding value, so it must be kept.
+static bool paddingIsNotNeeded(PackOp op) {
+  auto srcType = op.getSourceType();
+  ArrayRef<int64_t> innerDimsPos = op.getInnerDimsPos();
+  ArrayRef<int64_t> innerTiles = op.getStaticInnerTiles();
+  if (llvm::any_of(innerDimsPos,
+                   [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
+    return false;
+  if (ShapedType::isDynamicShape(innerTiles))
+    return false;
+  if (PackOp::requirePaddingValue(srcType.getShape(), innerDimsPos,
+                                  op.getMixedTiles()))
+    return false;
+
+  // Number of outer tiles the source needs along each source dimension;
+  // untiled dimensions keep their (possibly dynamic) size.
+  SmallVector<int64_t> neededOuterTiles = llvm::to_vector(srcType.getShape());
+  for (auto [pos, tile] : llvm::zip_equal(innerDimsPos, innerTiles))
+    neededOuterTiles[pos] /= tile;
+
+  // Destination outer dim i corresponds to source dim outer_dims_perm[i]
+  // (identity when the permutation is absent); each must match exactly.
+  ArrayRef<int64_t> destShape = op.getDestType().getShape();
+  ArrayRef<int64_t> outerDimsPerm = op.getOuterDimsPerm();
+  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
+    int64_t srcDim = outerDimsPerm.empty() ? i : outerDimsPerm[i];
+    if (destShape[i] != neededOuterTiles[srcDim])
+      return false;
+  }
+  return true;
+}
+
 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
-  UnPackOp unPackOp = packOp.getSource().getDefiningOp<UnPackOp>();
-  if (!unPackOp || unPackOp.getSourceType() != packOp.getDestType())
-    return failure();
-  if (packOp.getPaddingValue() ||
-      !hasSameInnerOuterAttribute(packOp, unPackOp) ||
-      !haveSameTiles(packOp, unPackOp))
-    return failure();
-  rewriter.replaceOp(packOp, unPackOp.getSource());
-  return success();
+  // Fold pack(unpack(x)) to x when the pack exactly undoes the unpack.
+  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
+    if (unPackOp.getSourceType() == packOp.getDestType() &&
+        !packOp.getPaddingValue() &&
+        hasSameInnerOuterAttribute(packOp, unPackOp) &&
+        haveSameTiles(packOp, unPackOp)) {
+      rewriter.replaceOp(packOp, unPackOp.getSource());
+      return success();
+    }
+  }
+
+  // Otherwise drop the optional padding_value operand if padding is provably
+  // unnecessary. A non-matching unpack producer must not short-circuit this:
+  // removing the padding may be exactly what enables the fold above on a
+  // later iteration.
+  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
+    rewriter.startRootUpdate(packOp);
+    packOp.getPaddingValueMutable().clear();
+    rewriter.finalizeRootUpdate(packOp);
+    return success();
+  }
+  return failure();
 }
 
 template <typename PackOrUnpackOp>

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 41bfd6fe7b6eedc..580c1db6070201f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -719,6 +719,80 @@ func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x1
 
 // -----
 
+// Every tiled dimension is static and evenly divided by its tile
+// (500000 = 31250 * 16, 1200 = 1200 * 1), so padding_value is dropped.
+func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<31250x1200x16x1xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
+  %pack = tensor.pack %arg0
+    padding_value(%cst : f32)
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 1]
+    into %0 : tensor<1200x500000xf32> -> tensor<31250x1200x16x1xf32>
+  return %pack : tensor<31250x1200x16x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack
+// CHECK:         tensor.pack
+// CHECK-NOT:     padding_value
+
+// -----
+
+// Negative: 499999 is not a multiple of the tile size 16, so the last tile
+// is incomplete and padding_value must be kept.
+func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
+  %pack = tensor.pack %arg0
+    padding_value(%cst : f32)
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 1]
+    into %0 : tensor<1200x499999xf32> -> tensor<31250x1200x16x1xf32>
+  return %pack : tensor<31250x1200x16x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack_negative1
+// CHECK:         tensor.pack
+// CHECK-SAME:      padding_value
+
+// -----
+
+// Negative: the tiled source dimension is dynamic, so divisibility by the
+// tile size cannot be proven statically.
+func.func @fold_padding_value_pack_negative2(%arg0: tensor<1200x?xf32>, %arg1: tensor<?x1200x16x1xf32>) -> tensor<?x1200x16x1xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %arg0
+    padding_value(%cst : f32)
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 1]
+    into %arg1 : tensor<1200x?xf32> -> tensor<?x1200x16x1xf32>
+  return %pack : tensor<?x1200x16x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack_negative2
+// CHECK:         tensor.pack
+// CHECK-SAME:      padding_value
+
+// -----
+
+// Negative: the inner tile size is dynamic, so divisibility cannot be proven
+// statically.
+func.func @fold_padding_value_pack_negative3(%arg0: tensor<1200x500000xf32>, %arg1: tensor<?x1200x?x1xf32>, %tile : index) -> tensor<?x1200x?x1xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %arg0
+    padding_value(%cst : f32)
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [1, 0]
+    inner_tiles = [%tile, 1]
+    into %arg1 : tensor<1200x500000xf32> -> tensor<?x1200x?x1xf32>
+  return %pack : tensor<?x1200x?x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack_negative3
+// CHECK:         tensor.pack
+// CHECK-SAME:      padding_value
+
+// -----
+
 // CHECK-LABEL: func @fold_unpack_constant_splat
 //   CHECK-NOT: tensor.unpack
 //       CHECK: arith.constant dense<1.000000e-01> : tensor<128x256xf32>


        


More information about the Mlir-commits mailing list