[Mlir-commits] [mlir] 30d542f - [MLIR][Tensor] Introduce a pattern to propagate through `tensor.pad`
Lorenzo Chelini
llvmlistbot at llvm.org
Tue Feb 14 23:49:01 PST 2023
Author: Lorenzo Chelini
Date: 2023-02-15T08:48:55+01:00
New Revision: 30d542f9b2adf10ee2cb7e07877ad7f0cfdbfea2
URL: https://github.com/llvm/llvm-project/commit/30d542f9b2adf10ee2cb7e07877ad7f0cfdbfea2
DIFF: https://github.com/llvm/llvm-project/commit/30d542f9b2adf10ee2cb7e07877ad7f0cfdbfea2.diff
LOG: [MLIR][Tensor] Introduce a pattern to propagate through `tensor.pad`
Introduce a pattern to 'push down' a `tensor.unpack` through a
`tensor.pad`. The propagation happens only if none of the padded
dimensions is a tiled one: the pad is recreated on the packed source
(with zero padding appended for the point loops) and the unpack is
re-emitted on the padded result.
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D143907
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
mlir/test/Dialect/Linalg/data-layout-propagation.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 1b6d1d247a2d..bf5e64ba1f34 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -465,10 +465,69 @@ struct PushDownUnPackOpThroughElemGenericOp
}
};
+/// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
+/// append as many zero-padding entries to `low` and `high` as there are
+/// point loops.
+struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
+ using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::PadOp padOp,
+ PatternRewriter &rewriter) const override {
+ tensor::UnPackOp unpackOp =
+ padOp.getSource().getDefiningOp<tensor::UnPackOp>();
+ if (!unpackOp)
+ return failure();
+
+ Location loc = padOp.getLoc();
+ // Bail out if any of the padded dimensions is a tiled one.
+ llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
+ ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
+ llvm::SmallBitVector innerDims(paddedDims.size());
+ for (int64_t dim : innerDimsPos)
+ innerDims.flip(dim);
+ if (paddedDims.anyCommon(innerDims))
+ return failure();
+
+ Value paddingVal = padOp.getConstantPaddingValue();
+ if (!paddingVal)
+ return failure();
+
+ // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
+ ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
+ SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
+ SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
+ if (!outerDimsPerm.empty()) {
+ applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
+ applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
+ }
+ // Add zero padding for the point loops.
+ size_t pointLoopsSize = innerDimsPos.size();
+ lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
+ highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
+
+ auto newPadOp = rewriter.create<tensor::PadOp>(
+ loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
+ paddingVal, padOp.getNofold());
+
+ // Inject the tensor.unpack right after the newly created padOp.
+ Value outputUnPack = rewriter.create<tensor::EmptyOp>(
+ loc, padOp.getResultType().getShape(),
+ padOp.getResultType().getElementType());
+
+ Value replacement = rewriter.create<tensor::UnPackOp>(
+ loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
+ unpackOp.getMixedTiles(), outerDimsPerm);
+ rewriter.replaceOp(padOp, replacement);
+ return success();
+ }
+};
+
} // namespace
void mlir::linalg::populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns) {
- patterns.insert<BubbleUpPackOpThroughElemGenericOpPattern,
- PushDownUnPackOpThroughElemGenericOp>(patterns.getContext());
+ patterns
+ .insert<BubbleUpPackOpThroughElemGenericOpPattern,
+ PushDownUnPackOpThroughElemGenericOp, PushDownUnPackThroughPadOp>(
+ patterns.getContext());
}
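For context, a minimal sketch of how a pass might exercise the updated
pattern set (the pass class name is hypothetical and the header paths are
assumptions; `populateDataLayoutPropagationPatterns` is the entry point
modified above, and `applyPatternsAndFoldGreedily` is MLIR's greedy
rewrite driver):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical pass body: collect the data-layout propagation patterns
// (now including PushDownUnPackThroughPadOp) and apply them greedily.
void ApplyDataLayoutPropagationPass::runOnOperation() {
  mlir::RewritePatternSet patterns(&getContext());
  mlir::linalg::populateDataLayoutPropagationPatterns(patterns);
  if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                std::move(patterns))))
    signalPassFailure();
}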
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index b699b3d19cee..32190cac8d2f 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -471,4 +471,70 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
// CHECK-SAME: outs(%[[DEST]]
// CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+// CHECK-SAME: into %[[ARG0_UNPACK_EMPTY]]
+
+// -----
+
+func.func @pad_valid_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58x58x64xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<1x56x56x64xf32>
+ %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+ %padded = tensor.pad %1 low[0, 1, 1, 0] high[0, 1, 1, 0] {
+ ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+ tensor.yield %cst : f32
+ } : tensor<1x56x56x64xf32> to tensor<1x58x58x64xf32>
+ return %padded : tensor<1x58x58x64xf32>
+}
+
+// CHECK: func.func @pad_valid_propagation(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>
+
+// -----
+
+func.func @pad_valid_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<2x58x58x64xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<1x56x56x64xf32>
+ %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+ %padded = tensor.pad %1 low[1, 1, 1, 0] high[0, 1, 1, 0] {
+ ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+ tensor.yield %cst : f32
+ } : tensor<1x56x56x64xf32> to tensor<2x58x58x64xf32>
+ return %padded : tensor<2x58x58x64xf32>
+}
+
+// CHECK: func.func @pad_valid_propagation(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[1, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x58x58x64xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[EMPTY]] : tensor<2x2x58x58x32xf32> -> tensor<2x58x58x64xf32>
+
+// -----
+
+func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58x58x66xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<1x56x56x64xf32>
+ %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+ %padded = tensor.pad %1 low[0, 1, 1, 1] high[0, 1, 1, 1] {
+ ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+ tensor.yield %cst : f32
+ } : tensor<1x56x56x64xf32> to tensor<1x58x58x66xf32>
+ return %padded : tensor<1x58x58x66xf32>
+}
+
+// CHECK: func.func @pad_along_unpacked_dim(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME: into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]
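
A note on the first test above: the low pad low[0, 1, 1, 0] on the
unpacked NHWC tensor becomes low[0, 0, 1, 1, 0] on the packed tensor
because the pad amounts are reindexed by outer_dims_perm = [0, 3, 1, 2]
and a trailing zero is appended for the single point loop. A standalone
sketch of that arithmetic (plain C++, mirroring what
applyPermutationToVector computes in the pattern):

#include <array>
#include <cstdio>

int main() {
  // result[i] = input[perm[i]], as in applyPermutationToVector.
  std::array<int, 4> lowPad = {0, 1, 1, 0};
  std::array<int, 4> perm = {0, 3, 1, 2};
  for (int i = 0; i < 4; ++i)
    std::printf("%d ", lowPad[perm[i]]); // prints: 0 0 1 1
  std::printf("0\n"); // trailing zero for the single point loop
  return 0;
}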