[Mlir-commits] [mlir] 6538808 - [mlir][tensor] Add patterns that fold ops into pack and unpack ops.
Hanhan Wang
llvmlistbot at llvm.org
Wed Jan 11 13:52:02 PST 2023
Author: Hanhan Wang
Date: 2023-01-11T13:51:49-08:00
New Revision: 65388086e68245c11c5acf5cd6b3570d8e4d11bf
URL: https://github.com/llvm/llvm-project/commit/65388086e68245c11c5acf5cd6b3570d8e4d11bf
DIFF: https://github.com/llvm/llvm-project/commit/65388086e68245c11c5acf5cd6b3570d8e4d11bf.diff
LOG: [mlir][tensor] Add patterns that fold ops into pack and unpack ops.
The tensor.pack ops have pad semantics, so we can fold pad + pack into
pack when:
1. They have the same padding values, or the pack op does not have
padding values.
2. The pad op does not have low paddings.
The tensor.unpack ops have extract_slice semantics, so we can fold unpack
+ extract_slice into unpack when:
1. All the offsets are 0s.
2. All the strides are 1s.
Reviewed By: tyb0807
Differential Revision: https://reviews.llvm.org/D141099
Added:
mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
Modified:
mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 430842b4584ff..01985c943527c 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -44,6 +44,11 @@ void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
/// tensor.[extract_slice|cast|expand_shape|collapse_shape].
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns);
+/// Populates `patterns` with patterns that fold operations like `tensor.pad`
+/// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
+/// respectively.
+void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);
+
} // namespace tensor
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index 216fc8edb9258..5ed3d97b2719f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
Bufferize.cpp
EmptyOpPatterns.cpp
ExtractSliceFromReshapeUtils.cpp
+ FoldIntoPackAndUnpackPatterns.cpp
MergeConsecutiveInsertExtractSlicePatterns.cpp
ReshapePatterns.cpp
SplitPaddingPatterns.cpp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
new file mode 100644
index 0000000000000..744e49edcb6c9
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -0,0 +1,87 @@
+//===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace tensor {
+namespace {
+
+static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
+ return llvm::all_of(
+ ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
+}
+
+/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
+/// the pad op has zero low paddings, or if `pack` has no padding values.
+struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
+ using OpRewritePattern<PackOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(PackOp packOp,
+ PatternRewriter &rewriter) const override {
+ auto padOp = packOp.getSource().getDefiningOp<PadOp>();
+
+ if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
+ return failure();
+
+ Value constantPaddingValue = padOp.getConstantPaddingValue();
+ if (!constantPaddingValue)
+ return failure();
+
+ if (auto paddingValue = packOp.getPaddingValue())
+ if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<PackOp>(
+ packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
+ packOp.getMixedTiles(), constantPaddingValue,
+ packOp.getOuterDimsPerm());
+ return success();
+ }
+};
+
+/// Fold a `unpack` -> `extract_slice` into the `unpack` since it already
+/// has extract_slice semantics.
+struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
+ using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
+ if (!unpackOp)
+ return failure();
+
+ // Check all offsets are zeros, and all strides are ones.
+ if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
+ !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
+ return rewriter.notifyMatchFailure(
+ sliceOp, "expects offsets to be 0s and strides to be 1s");
+ }
+
+ // Create a new empty output tensor.
+ Type elementType = unpackOp.getDestType().getElementType();
+ Value output = rewriter.create<EmptyOp>(
+ sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
+ rewriter.replaceOpWithNewOp<UnPackOp>(
+ sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
+ unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
+ return success();
+ }
+};
+} // namespace
+
+void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
+ patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp>(
+ patterns.getContext());
+}
+
+} // namespace tensor
+} // namespace mlir
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
new file mode 100644
index 0000000000000..0981faf8a1f26
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -0,0 +1,103 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-into-pack-and-unpack %s | FileCheck %s
+
+func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+ : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+ %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+// CHECK: func @fold_unpack_slice(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x8x4xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK: %[[INIT:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]]) : tensor<?x?xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4]
+// CHECK-SAME: into %[[INIT]]
+// CHECK: return %[[UNPACK]]
+
+// -----
+
+func.func @nofold_unpack_slice_non_zero_offset(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+ : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+ %1 = tensor.extract_slice %0[0, %arg4] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @nofold_unpack_slice_non_zero_offset(
+// CHECK: %[[UNPACK:.+]] = tensor.unpack
+// CHECK: tensor.extract_slice %[[UNPACK]]
+
+// -----
+
+func.func @nofold_unpack_slice_non_unit_stride(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+ %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+ : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+ %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [%arg4, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @nofold_unpack_slice_non_unit_stride(
+// CHECK: %[[UNPACK:.+]] = tensor.unpack
+// CHECK: tensor.extract_slice %[[UNPACK]]
+
+// -----
+
+func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %padded = tensor.pad %src low[0, 0] high[15, 0] {
+ ^bb0(%arg0: index, %arg1: index):
+ tensor.yield %cst : f32
+ } : tensor<16641x16xf32> to tensor<16656x16xf32>
+ %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+ %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+ : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+ return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @pad_pack
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK: %[[PAD_VAL:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[DEST:.+]] = tensor.empty() : tensor<2082x1x8x32xf32>
+// CHECK: %[[PACK:.+]] = tensor.pack %[[SRC]]
+// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DEST]]
+
+// -----
+
+func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
+ ^bb0(%arg0: index, %arg1: index):
+ tensor.yield %cst : f32
+ } : tensor<16641x16xf32> to tensor<16656x16xf32>
+ %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+ %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+ : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+ return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @nofold_pad_pack
+// CHECK: tensor.pad
+// CHECK: tensor.pack
+
+// -----
+
+func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+ %c0 = arith.constant 0 : index
+ %cst0 = arith.constant 0.000000e+00 : f32
+ %cst1 = arith.constant 1.000000e+00 : f32
+ %padded = tensor.pad %src low[0, 0] high[15, 0] {
+ ^bb0(%arg0: index, %arg1: index):
+ tensor.yield %cst0 : f32
+ } : tensor<16641x16xf32> to tensor<16656x16xf32>
+ %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+ %pack = tensor.pack %padded padding_value(%cst1 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+ : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+ return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @pad_pack_different_padding_value
+// CHECK: tensor.pad
+// CHECK: tensor.pack
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index fed6aecfe839d..a87547035e75b 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -74,6 +74,11 @@ struct TestTensorTransforms
*this, "test-empty-op-folding",
llvm::cl::desc("Test folding of tensor.empty"), llvm::cl::init(false)};
+ Option<bool> testFoldIntoPackAndUnpack{
+ *this, "test-fold-into-pack-and-unpack",
+ llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack"),
+ llvm::cl::init(false)};
+
Option<bool> useForeach{
*this, "use-foreach",
llvm::cl::desc(
@@ -95,6 +100,12 @@ static void applyEmptyOpFoldingPatterns(Operation *rootOp) {
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
+static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
+ RewritePatternSet patterns(rootOp->getContext());
+ tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
static void applySplitPaddingPatterns(Operation *rootOp) {
RewritePatternSet patterns(rootOp->getContext());
tensor::populateSplitPaddingPatterns(patterns);
@@ -276,6 +287,8 @@ void TestTensorTransforms::runOnOperation() {
applyReassociativeReshapeFoldingPatterns(rootOp);
if (testEmptyOpFolding)
applyEmptyOpFoldingPatterns(rootOp);
+ if (testFoldIntoPackAndUnpack)
+ applyFoldIntoPackAndUnpackPatterns(rootOp);
if (testRewriteExtractSliceWithTiledCollapseShape) {
if (failed(
applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
More information about the Mlir-commits
mailing list