[Mlir-commits] [mlir] [mlir][tensor] Fold padding_value away for pack ops when possible. (PR #74005)
Han-Chung Wang
llvmlistbot at llvm.org
Thu Nov 30 16:02:37 PST 2023
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/74005
>From 943e71f9ef37e8d9c32598323974de42210f378f Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 30 Nov 2023 15:28:48 -0800
Subject: [PATCH 1/2] [mlir][tensor] Fold padding_value away for pack ops when
possible.
If we can statically infer that there are no incomplete tiles (every tiled source dimension is a multiple of its inner tile size), we can
remove the optional padding_value operand.
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 40 +++++++++++++-----
mlir/test/Dialect/Tensor/canonicalize.mlir | 49 ++++++++++++++++++++++
2 files changed, 79 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 5bfcb35127b5267..30b719924dea56f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3800,17 +3800,37 @@ static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
return true;
}
-/// Fold an unpack(pack(x)) to x.
+/// Returns true only if the pack op provably does not need a padding value.
+static bool paddingIsNotNeeded(PackOp op) {
+ auto srcType = op.getSourceType();
+ if (llvm::any_of(op.getInnerDimsPos(),
+ [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
+ return false;
+ return !PackOp::requirePaddingValue(srcType.getShape(), op.getInnerDimsPos(),
+ op.getMixedTiles());
+}
+
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
- UnPackOp unPackOp = packOp.getSource().getDefiningOp<UnPackOp>();
- if (!unPackOp || unPackOp.getSourceType() != packOp.getDestType())
- return failure();
- if (packOp.getPaddingValue() ||
- !hasSameInnerOuterAttribute(packOp, unPackOp) ||
- !haveSameTiles(packOp, unPackOp))
- return failure();
- rewriter.replaceOp(packOp, unPackOp.getSource());
- return success();
+ // Fold pack(unpack(x)) to x. A mismatch returns failure, so the padding fold below is skipped.
+ if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
+ if (unPackOp.getSourceType() != packOp.getDestType())
+ return failure();
+ if (packOp.getPaddingValue() ||
+ !hasSameInnerOuterAttribute(packOp, unPackOp) ||
+ !haveSameTiles(packOp, unPackOp))
+ return failure();
+ rewriter.replaceOp(packOp, unPackOp.getSource());
+ return success();
+ }
+
+ // Otherwise, drop the optional padding_value operand in place if padding is provably not needed.
+ if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
+ rewriter.startRootUpdate(packOp);
+ packOp.getPaddingValueMutable().clear();
+ rewriter.finalizeRootUpdate(packOp);
+ return success();
+ }
+ return failure();
}
template <typename PackOrUnpackOp>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 41bfd6fe7b6eedc..68ce35c866eb034 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -719,6 +719,55 @@ func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x1
// -----
+func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<31250x1200x16x1xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
+ %pack = tensor.pack %arg0
+ padding_value(%cst : f32)
+ outer_dims_perm = [1, 0]
+ inner_dims_pos = [1, 0]
+ inner_tiles = [16, 1] // 500000 = 31250 * 16 and 1200 = 1200 * 1: no incomplete tiles, so the padding operand can be folded.
+ into %0 : tensor<1200x500000xf32> -> tensor<31250x1200x16x1xf32>
+ return %pack : tensor<31250x1200x16x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack
+// CHECK-NOT: padding_value
+
+// -----
+
+func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
+ %pack = tensor.pack %arg0
+ padding_value(%cst : f32)
+ outer_dims_perm = [1, 0]
+ inner_dims_pos = [1, 0]
+ inner_tiles = [16, 1] // 499999 is not a multiple of 16: the last tile along dim 1 is incomplete, so padding is needed.
+ into %0 : tensor<1200x499999xf32> -> tensor<31250x1200x16x1xf32>
+ return %pack : tensor<31250x1200x16x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack_negative1
+// CHECK: tensor.pack
+// CHECK-SAME: padding_value
+
+// -----
+
+func.func @fold_padding_value_pack_negative2(%arg0: tensor<1200x?xf32>, %arg1: tensor<?x1200x16x1xf32>) -> tensor<?x1200x16x1xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = tensor.pack %arg0
+ padding_value(%cst : f32)
+ outer_dims_perm = [1, 0]
+ inner_dims_pos = [1, 0]
+ inner_tiles = [16, 1] // source dim 1 is dynamic, so divisibility by 16 cannot be proven statically.
+ into %arg1 : tensor<1200x?xf32> -> tensor<?x1200x16x1xf32>
+ return %pack : tensor<?x1200x16x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack_negative2
+// CHECK: tensor.pack
+// CHECK-SAME: padding_value
+
+// -----
+
// CHECK-LABEL: func @fold_unpack_constant_splat
// CHECK-NOT: tensor.unpack
// CHECK: arith.constant dense<1.000000e-01> : tensor<128x256xf32>
>From 66584320a33e989006a44ffd7e1fa3775d0aa403 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 30 Nov 2023 16:01:58 -0800
Subject: [PATCH 2/2] Add a check for dynamic inner tile sizes
---
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 3 +++
mlir/test/Dialect/Tensor/canonicalize.mlir | 16 ++++++++++++++++
2 files changed, 19 insertions(+)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 30b719924dea56f..cd9b82d2c553fae 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
@@ -3806,6 +3807,8 @@ static bool paddingIsNotNeeded(PackOp op) {
if (llvm::any_of(op.getInnerDimsPos(),
[&](int64_t pos) { return srcType.isDynamicDim(pos); }))
return false;
+ if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
+ return false;
return !PackOp::requirePaddingValue(srcType.getShape(), op.getInnerDimsPos(),
op.getMixedTiles());
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 68ce35c866eb034..580c1db6070201f 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -768,6 +768,22 @@ func.func @fold_padding_value_pack_negative2(%arg0: tensor<1200x?xf32>, %arg1: t
// -----
+func.func @fold_padding_value_pack_negative3(%arg0: tensor<1200x500000xf32>, %arg1: tensor<?x1200x?x1xf32>, %tile : index) -> tensor<?x1200x?x1xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %pack = tensor.pack %arg0
+ padding_value(%cst : f32)
+ outer_dims_perm = [1, 0]
+ inner_dims_pos = [1, 0]
+ inner_tiles = [%tile, 1] // the tile size for dim 1 is dynamic, so divisibility cannot be proven statically.
+ into %arg1 : tensor<1200x500000xf32> -> tensor<?x1200x?x1xf32>
+ return %pack : tensor<?x1200x?x1xf32>
+}
+// CHECK-LABEL: func @fold_padding_value_pack_negative3
+// CHECK: tensor.pack
+// CHECK-SAME: padding_value
+
+// -----
+
// CHECK-LABEL: func @fold_unpack_constant_splat
// CHECK-NOT: tensor.unpack
// CHECK: arith.constant dense<1.000000e-01> : tensor<128x256xf32>
More information about the Mlir-commits
mailing list