[Mlir-commits] [mlir] 6538808 - [mlir][tensor] Add patterns that fold ops into pack and unpack ops.

Hanhan Wang llvmlistbot at llvm.org
Wed Jan 11 13:52:02 PST 2023


Author: Hanhan Wang
Date: 2023-01-11T13:51:49-08:00
New Revision: 65388086e68245c11c5acf5cd6b3570d8e4d11bf

URL: https://github.com/llvm/llvm-project/commit/65388086e68245c11c5acf5cd6b3570d8e4d11bf
DIFF: https://github.com/llvm/llvm-project/commit/65388086e68245c11c5acf5cd6b3570d8e4d11bf.diff

LOG: [mlir][tensor] Add patterns that fold ops into pack and unpack ops.

tensor.pack ops have padding semantics, so we can fold pad + pack into
pack when

1. They have the same padding values or the pack op does not have
   padding values.
2. The pad op does not have low paddings.

tensor.unpack ops have extract_slice semantics, so we can fold unpack
+ extract_slice into unpack when

1. All the offsets are 0s.
2. All the strides are 1s.

Reviewed By: tyb0807

Differential Revision: https://reviews.llvm.org/D141099

Added: 
    mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
    mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir

Modified: 
    mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
    mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
    mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 430842b4584ff..01985c943527c 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -44,6 +44,11 @@ void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
 /// tensor.[extract_slice|cast|expand_shape|collapse_shape].
 void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that fold operations like `tensor.pad`
+/// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
+/// respectively.
+void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);
+
 } // namespace tensor
 } // namespace mlir
 

diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index 216fc8edb9258..5ed3d97b2719f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   Bufferize.cpp
   EmptyOpPatterns.cpp
   ExtractSliceFromReshapeUtils.cpp
+  FoldIntoPackAndUnpackPatterns.cpp
   MergeConsecutiveInsertExtractSlicePatterns.cpp
   ReshapePatterns.cpp
   SplitPaddingPatterns.cpp

diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
new file mode 100644
index 0000000000000..744e49edcb6c9
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -0,0 +1,87 @@
+//===- FoldIntoPackAndUnpackPatterns.cpp ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace tensor {
+namespace {
+
+static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
+  return llvm::all_of(
+      ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
+}
+
+/// Fold a `pad` -> `pack` into `pack` if the pad op has zero low padding and
+/// the `pack` either has no padding value or uses the same one as the pad.
+struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto padOp = packOp.getSource().getDefiningOp<PadOp>();
+
+    if (!padOp || padOp.getNofold() || !padOp.hasZeroLowPad())
+      return failure();
+
+    Value constantPaddingValue = padOp.getConstantPaddingValue();
+    if (!constantPaddingValue)
+      return failure();
+
+    if (auto paddingValue = packOp.getPaddingValue())
+      if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
+        return failure();
+
+    rewriter.replaceOpWithNewOp<PackOp>(
+        packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
+        packOp.getMixedTiles(), constantPaddingValue,
+        packOp.getOuterDimsPerm());
+    return success();
+  }
+};
+
+/// Fold an `unpack` -> `extract_slice` into the `unpack` since it already
+/// has extract_slice semantics.
+struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
+  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto unpackOp = sliceOp.getSource().getDefiningOp<UnPackOp>();
+    if (!unpackOp)
+      return failure();
+
+    // Check all offsets are zeros, and all strides are ones.
+    if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
+        !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "expects offsets to be 0s and strides to be 1s");
+    }
+
+    // Create a new empty output tensor.
+    Type elementType = unpackOp.getDestType().getElementType();
+    Value output = rewriter.create<EmptyOp>(
+        sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType);
+    rewriter.replaceOpWithNewOp<UnPackOp>(
+        sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(),
+        unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm());
+    return success();
+  }
+};
+} // namespace
+
+void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
+  patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp>(
+      patterns.getContext());
+}
+
+} // namespace tensor
+} // namespace mlir

diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
new file mode 100644
index 0000000000000..0981faf8a1f26
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -0,0 +1,103 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-into-pack-and-unpack  %s | FileCheck %s
+
+func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+//      CHECK: func @fold_unpack_slice(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x8x4xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
+//      CHECK:   %[[INIT:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]]) : tensor<?x?xf32>
+//      CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4]
+// CHECK-SAME:       into %[[INIT]]
+//      CHECK:   return %[[UNPACK]]
+
+// -----
+
+func.func @nofold_unpack_slice_non_zero_offset(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[0, %arg4] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @nofold_unpack_slice_non_zero_offset(
+//       CHECK:   %[[UNPACK:.+]] = tensor.unpack
+//       CHECK:   tensor.extract_slice %[[UNPACK]]
+
+// -----
+
+func.func @nofold_unpack_slice_non_unit_stride(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : index, %arg3 : index, %arg4 : index) -> tensor<?x?xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
+      : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [%arg4, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @nofold_unpack_slice_non_unit_stride(
+//       CHECK:   %[[UNPACK:.+]] = tensor.unpack
+//       CHECK:   tensor.extract_slice %[[UNPACK]]
+
+// -----
+
+func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %src low[0, 0] high[15, 0] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : f32
+  } : tensor<16641x16xf32> to tensor<16656x16xf32>
+  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+  return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @pad_pack
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK:         %[[PAD_VAL:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:         %[[DEST:.+]] = tensor.empty() : tensor<2082x1x8x32xf32>
+// CHECK:         %[[PACK:.+]] = tensor.pack %[[SRC]]
+// CHECK-SAME:      padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %[[DEST]]
+
+// -----
+
+func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : f32
+  } : tensor<16641x16xf32> to tensor<16656x16xf32>
+  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  %pack = tensor.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+  return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @nofold_pad_pack
+// CHECK:         tensor.pad
+// CHECK:         tensor.pack
+
+// -----
+
+func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+  %c0 = arith.constant 0 : index
+  %cst0 = arith.constant 0.000000e+00 : f32
+  %cst1 = arith.constant 1.000000e+00 : f32
+  %padded = tensor.pad %src low[0, 0] high[15, 0] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst0 : f32
+  } : tensor<16641x16xf32> to tensor<16656x16xf32>
+  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  %pack = tensor.pack %padded padding_value(%cst1 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+  return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @pad_pack_different_padding_value
+// CHECK:         tensor.pad
+// CHECK:         tensor.pack

diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index fed6aecfe839d..a87547035e75b 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -74,6 +74,11 @@ struct TestTensorTransforms
       *this, "test-empty-op-folding",
       llvm::cl::desc("Test folding of tensor.empty"), llvm::cl::init(false)};
 
+  Option<bool> testFoldIntoPackAndUnpack{
+      *this, "test-fold-into-pack-and-unpack",
+      llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack"),
+      llvm::cl::init(false)};
+
   Option<bool> useForeach{
       *this, "use-foreach",
       llvm::cl::desc(
@@ -95,6 +100,12 @@ static void applyEmptyOpFoldingPatterns(Operation *rootOp) {
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 static void applySplitPaddingPatterns(Operation *rootOp) {
   RewritePatternSet patterns(rootOp->getContext());
   tensor::populateSplitPaddingPatterns(patterns);
@@ -276,6 +287,8 @@ void TestTensorTransforms::runOnOperation() {
     applyReassociativeReshapeFoldingPatterns(rootOp);
   if (testEmptyOpFolding)
     applyEmptyOpFoldingPatterns(rootOp);
+  if (testFoldIntoPackAndUnpack)
+    applyFoldIntoPackAndUnpackPatterns(rootOp);
   if (testRewriteExtractSliceWithTiledCollapseShape) {
     if (failed(
             applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))


        


More information about the Mlir-commits mailing list