[Mlir-commits] [mlir] 30d542f - [MLIR][Tensor] Introduce a pattern to propagate through `tensor.pad`

Lorenzo Chelini llvmlistbot at llvm.org
Tue Feb 14 23:49:01 PST 2023


Author: Lorenzo Chelini
Date: 2023-02-15T08:48:55+01:00
New Revision: 30d542f9b2adf10ee2cb7e07877ad7f0cfdbfea2

URL: https://github.com/llvm/llvm-project/commit/30d542f9b2adf10ee2cb7e07877ad7f0cfdbfea2
DIFF: https://github.com/llvm/llvm-project/commit/30d542f9b2adf10ee2cb7e07877ad7f0cfdbfea2.diff

LOG: [MLIR][Tensor] Introduce a pattern to propagate through `tensor.pad`

Introduce a pattern to 'push down' a `tensor.unpack` through a
`tensor.pad`. The propagation happens only if none of the padded
dimensions is tiled by the unpack and the padding value is constant.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D143907

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
    mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 1b6d1d247a2d..bf5e64ba1f34 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -465,10 +465,69 @@ struct PushDownUnPackOpThroughElemGenericOp
   }
 };
 
+/// Push a tensor.unpack that feeds a tensor.pad below the pad: pad the packed
+/// source instead (with zero low/high padding on every point loop) and unpack
+/// the result. Bails out on padded tiled dims or a non-constant padding value.
+struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::UnPackOp unpackOp =
+        padOp.getSource().getDefiningOp<tensor::UnPackOp>();
+    if (!unpackOp)
+      return failure();
+
+    Location loc = padOp.getLoc();
+    // Bail out if any padded dimension is also a tiled (inner) dimension.
+    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
+    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
+    llvm::SmallBitVector innerDims(paddedDims.size());
+    for (int64_t dim : innerDimsPos)
+      innerDims.flip(dim);
+    if (paddedDims.anyCommon(innerDims))
+      return failure();
+
+    Value paddingVal = padOp.getConstantPaddingValue();
+    if (!paddingVal)
+      return failure();
+
+    // If `outer_dims_perm` is set, permute the low/high padding accordingly.
+    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
+    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
+    if (!outerDimsPerm.empty()) {
+      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
+      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
+    }
+    // Append one zero low/high padding entry per point loop (inner tile dim).
+    size_t pointLoopsSize = innerDimsPos.size();
+    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
+    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
+        paddingVal, padOp.getNofold());
+
+    // Unpack the new pad into a tensor.empty of the old pad's result type.
+    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
+        loc, padOp.getResultType().getShape(),
+        padOp.getResultType().getElementType());
+
+    Value replacement = rewriter.create<tensor::UnPackOp>(
+        loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
+        unpackOp.getMixedTiles(), outerDimsPerm);
+    rewriter.replaceOp(padOp, replacement);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::linalg::populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns) {
-  patterns.insert<BubbleUpPackOpThroughElemGenericOpPattern,
-                  PushDownUnPackOpThroughElemGenericOp>(patterns.getContext());
+  patterns
+      .insert<BubbleUpPackOpThroughElemGenericOpPattern,
+              PushDownUnPackOpThroughElemGenericOp, PushDownUnPackThroughPadOp>(
+          patterns.getContext());
 }

diff  --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index b699b3d19cee..32190cac8d2f 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -471,4 +471,70 @@ func.func @forward_tensor_empty(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x5
 // CHECK-SAME:  outs(%[[DEST]]
 // CHECK: %[[UNPACKED:.+]] = tensor.unpack %[[RES]]
 // CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
-// CHECK-SAME:  into %[[ARG0_UNPACK_EMPTY]] 
+// CHECK-SAME:  into %[[ARG0_UNPACK_EMPTY]]
+
+// -----
+
+func.func @pad_valid_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58x58x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+  %padded = tensor.pad %1 low[0, 1, 1, 0] high[0, 1, 1, 0] {
+    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<1x56x56x64xf32> to tensor<1x58x58x64xf32>
+  return %padded : tensor<1x58x58x64xf32>
+}
+
+// CHECK: func.func @pad_valid_propagation(
+// CHECK-SAME:  %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[0, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x58x58x64xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]] 
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:  into %[[EMPTY]] : tensor<1x2x58x58x32xf32> -> tensor<1x58x58x64xf32>
+
+// -----
+
+func.func @pad_valid_propagation(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<2x58x58x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+  %padded = tensor.pad %1 low[1, 1, 1, 0] high[0, 1, 1, 0] {
+    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<1x56x56x64xf32> to tensor<2x58x58x64xf32>
+  return %padded : tensor<2x58x58x64xf32>
+}
+
+// CHECK: func.func @pad_valid_propagation(
+// CHECK-SAME:  %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[ARG0]] low[1, 0, 1, 1, 0] high[0, 0, 1, 1, 0]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<2x58x58x64xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PADDED]]
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
+// CHECK-SAME:  into %[[EMPTY]] : tensor<2x2x58x58x32xf32> -> tensor<2x58x58x64xf32>
+
+// -----
+
+func.func @pad_along_unpacked_dim(%arg0: tensor<1x2x56x56x32xf32>) -> tensor<1x58x58x66xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<1x56x56x64xf32>
+  %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+  %padded = tensor.pad %1 low[0, 1, 1, 1] high[0, 1, 1, 1] {
+    ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<1x56x56x64xf32> to tensor<1x58x58x66xf32>
+  return %padded : tensor<1x58x58x66xf32>
+}
+
+// CHECK: func.func @pad_along_unpacked_dim(
+// CHECK: %[[ARG0:.+]]: tensor<1x2x56x56x32xf32>)
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x56x56x64xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]] 
+// CHECK-SAME:  outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:  into %[[EMPTY]] : tensor<1x2x56x56x32xf32> -> tensor<1x56x56x64xf32>
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[UNPACK]] low[0, 1, 1, 1] high[0, 1, 1, 1]


        


More information about the Mlir-commits mailing list