[Mlir-commits] [mlir] 973dbe2 - [mlir][tensor] Add pattern to fold ExtractSliceOp, PadOp chains.

llvmlistbot at llvm.org
Mon Apr 11 07:29:56 PDT 2022


Author: gysit
Date: 2022-04-11T14:28:59Z
New Revision: 973dbe20f681ca885edd0f5e63cde62dbdb6c186

URL: https://github.com/llvm/llvm-project/commit/973dbe20f681ca885edd0f5e63cde62dbdb6c186
DIFF: https://github.com/llvm/llvm-project/commit/973dbe20f681ca885edd0f5e63cde62dbdb6c186.diff

LOG: [mlir][tensor] Add pattern to fold ExtractSliceOp, PadOp chains.

The pattern folds chains of tensor::ExtractSliceOp, tensor::PadOp pairs if they pad different dimensions. Repeated tiling and padding of the tiled dimensions may introduce such chains. This canonicalization pattern folds these chains into a single tensor::ExtractSliceOp, tensor::PadOp pair that pads all dimensions at once, which simplifies vectorization and bufferization.

Example:
```mlir
   %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
       : tensor<64x64xf32> to tensor<?x64xf32>
   %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
     } : tensor<?x64xf32> to tensor<8x64xf32>
   %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
        : tensor<8x64xf32> to tensor<8x?xf32>
   %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
     } : tensor<8x?xf32> to tensor<8x4xf32>
```
folds into:
```mlir
   %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
        : tensor<64x64xf32> to tensor<?x?xf32>
   %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
     } : tensor<?x?xf32> to tensor<8x4xf32>
```
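
The new pattern is registered through `PadOp::getCanonicalizationPatterns` (see the change at the bottom of the diff), so any canonicalization run picks it up. Below is a minimal C++ sketch of such a run; the helper name `runPadCanonicalizations` and the standalone-driver setup are illustrative, not part of this commit:

```c++
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Illustrative helper: collects the tensor.pad canonicalization patterns
// (which now include FoldOrthogonalPaddings) and applies them greedily to
// a parsed module.
static mlir::LogicalResult runPadCanonicalizations(mlir::ModuleOp module) {
  mlir::MLIRContext *ctx = module.getContext();
  mlir::RewritePatternSet patterns(ctx);
  mlir::tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
  return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
}
```

The same rewrite is also reachable via `mlir-opt -canonicalize`, which is how the new cases in canonicalize.mlir below exercise it.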

Reviewed By: nicolasvasilache, hanchung

Differential Revision: https://reviews.llvm.org/D122722

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/test/Dialect/Tensor/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 30c74837efa79..660c80865abda 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -678,7 +678,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     Tensor_Op<mnemonic, !listconcat(traits, [NoSideEffect])>,
     Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>,
     Results<(outs AnyTensor:$result)> {
- 
+
   code commonExtraClassDeclaration = [{
     static StringRef getReassociationAttrName() { return "reassociation"; }
     SmallVector<AffineMap, 4> getReassociationMaps();
@@ -982,6 +982,8 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
         return getConstantIntValue(ofr) == static_cast<int64_t>(0);
       });
     }
+    /// Return the dimensions with a non-zero low or high padding.
+    llvm::SmallBitVector getPaddedDims();
   }];
 
   let builders = [

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 5b52a3fdd24ce..f1181e63ec8f2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1858,6 +1858,18 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
   result.addAttributes(attrs);
 }
 
+llvm::SmallBitVector PadOp::getPaddedDims() {
+  llvm::SmallBitVector paddedDims(getSourceType().getRank());
+  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
+    for (const auto &en : enumerate(paddingWidths))
+      if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
+        paddedDims.set(en.index());
+  };
+  extractPaddedDims(getMixedLowPad());
+  extractPaddedDims(getMixedHighPad());
+  return paddedDims;
+}
+
 namespace {
 // Folds tensor.pad when padding is static zeros and the attribute
 // doesn't request otherwise.
@@ -1940,13 +1952,169 @@ struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
     return success();
   }
 };
+
+/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
+/// different dimensions. The pattern applies if the following preconditions
+/// hold:
+///   1) the tensor::ExtractSliceOps are not rank-reducing,
+///   2) the tensor::ExtractSliceOps have only unit-strides,
+///   3) the tensor::PadOps perform only high-padding,
+///   4) the tensor::PadOps have the same constant padding value,
+///   5) the tensor::PadOps do not have common padding dimensions,
+///   6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
+///      zero-offset for every dimension, and
+///   7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the
+///      padded source dimensions.
+///
+/// Example:
+///
+/// ```mlir
+///   %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
+///       : tensor<64x64xf32> to tensor<?x64xf32>
+///   %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
+///     } : tensor<?x64xf32> to tensor<8x64xf32>
+///   %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
+///        : tensor<8x64xf32> to tensor<8x?xf32>
+///   %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
+///     } : tensor<8x?xf32> to tensor<8x4xf32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
+///        : tensor<64x64xf32> to tensor<?x?xf32>
+///   %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
+///     } : tensor<?x?xf32> to tensor<8x4xf32>
+/// ```
+struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
+  using OpRewritePattern<PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    auto innerSliceOp = padOp.source().getDefiningOp<ExtractSliceOp>();
+    if (!innerSliceOp)
+      return failure();
+    auto outerPadOp = innerSliceOp.source().getDefiningOp<PadOp>();
+    if (!outerPadOp || outerPadOp.nofold())
+      return failure();
+    auto outerSliceOp = outerPadOp.source().getDefiningOp<ExtractSliceOp>();
+    if (!outerSliceOp)
+      return failure();
+
+    // 1) Fail if the chain is rank-reducing.
+    int64_t rank = padOp.getSourceType().getRank();
+    if (outerSliceOp.getSourceType().getRank() != rank) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "cannot fold rank-reducing chain");
+    }
+
+    // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
+    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(
+          padOp, "cannot fold non-unit stride ExtractSliceOps");
+    }
+
+    // 3) Fail if the tensor::PadOps have non-zero low padding.
+    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "cannot fold PadOps with low padding");
+    }
+
+    // 4) Fail if the tensor::PadOps padding values do not match.
+    Attribute innerAttr, outerAttr;
+    Value innerValue = padOp.getConstantPaddingValue();
+    Value outerValue = outerPadOp.getConstantPaddingValue();
+    if (!innerValue || !outerValue ||
+        !matchPattern(innerValue, m_Constant(&innerAttr)) ||
+        !matchPattern(outerValue, m_Constant(&outerAttr)) ||
+        innerAttr != outerAttr) {
+      return rewriter.notifyMatchFailure(
+          padOp, "cannot fold PadOps with 
diff erent padding values");
+    }
+
+    // 5) Fail if a dimension is padded by both tensor::PadOps.
+    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
+    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
+    if (innerDims.anyCommon(outerDims)) {
+      return rewriter.notifyMatchFailure(
+          padOp, "cannot fold PadOps with common padding dimensions");
+    }
+
+    // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
+    // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
+    // for every dimension, and use the offset of the other pair. Fail if no
+    // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
+    // exists.
+    SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
+    for (auto &en : enumerate(newOffsets)) {
+      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
+      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
+      if (!innerDims.test(en.index()) &&
+          (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
+        en.value() = outerOffset;
+        continue;
+      }
+      if (!outerDims.test(en.index()) &&
+          (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
+        en.value() = innerOffset;
+        continue;
+      }
+      return rewriter.notifyMatchFailure(
+          padOp, "cannot find zero-offset and zero-padding pair");
+    }
+
+    // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size of
+    // the outer tensor::ExtractSliceOp for the dimensions padded by the outer
+    // tensor::PadOp and fail if the size of the inner tensor::ExtractSliceOp
+    // does not match the size of the padded dimension. Otherwise, take the size
+    // of the inner tensor::ExtractSliceOp.
+    SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
+    for (auto &en : enumerate(newSizes)) {
+      if (!outerDims.test(en.index()))
+        continue;
+      OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
+      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
+      assert(!ShapedType::isDynamic(sourceSize) &&
+             "expected padded dimension to have a static size");
+      if (getConstantIntValue(sliceSize) != sourceSize) {
+        return rewriter.notifyMatchFailure(
+            padOp, "cannot fold since the inner ExtractSliceOp size does not "
+                   "match the size of the outer padding");
+      }
+      en.value() = outerSliceOp.getMixedSizes()[en.index()];
+    }
+
+    // Combine the high paddings of the two tensor::PadOps.
+    SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
+    for (auto &en : enumerate(newHighPad)) {
+      if (innerDims.test(en.index()))
+        newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
+      if (outerDims.test(en.index()))
+        newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
+    }
+
+    // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs the
+    // two paddings in one step.
+    auto newSliceOp = rewriter.create<ExtractSliceOp>(
+        padOp.getLoc(), outerSliceOp.source(), newOffsets, newSizes,
+        innerSliceOp.getMixedStrides());
+    auto newPadOp = rewriter.create<PadOp>(
+        padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
+        padOp.getMixedLowPad(), newHighPad, padOp.nofold());
+    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
+                                newPadOp.getRegion().begin());
+    rewriter.replaceOp(padOp, newPadOp.getResult());
+    return success();
+  }
+};
+
 } // namespace
 
 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results
-      .add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast>(
-          context);
+  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
+              FoldOrthogonalPaddings>(context);
 }
 
 /// Return the padding value of the PadOp if it is constant. In this context,

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 493f278763bb7..fef9617e90195 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1252,6 +1252,95 @@ func @pad_nofold_static_zero(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tenso
 
 // -----
 
+// CHECK-LABEL: func @fold_orthogonal_pad_chains(
+//  CHECK-SAME:   %[[ARG0:.*]]: tensor<64x64xf32>,
+//  CHECK-SAME:   %[[SZ0:.*]]: index, %[[SZ1:.*]]: index, %[[PW0:.*]]: index, %[[PW1:.*]]: index
+func.func @fold_orthogonal_pad_chains(%arg0: tensor<64x64xf32>,
+                                      %sz0 : index, %sz1 : index,
+                                      %pw0 : index, %pw1 : index) -> tensor<8x4xf32> {
+  //       CHECK:   %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+  //  CHECK-SAME:                     [16, 4] [%[[SZ0]], %[[SZ1]]]
+  //       CHECK:   %[[PAD:.*]] = tensor.pad %[[T0]] nofold
+  //  CHECK-SAME:                     high[%[[PW0]], %[[PW1]]]
+  //       CHECK:   return %[[PAD]]
+  %pad_value = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %arg0[16, 0] [%sz0, 64] [1, 1] : tensor<64x64xf32> to tensor<?x64xf32>
+  %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } : tensor<?x64xf32> to tensor<8x64xf32>
+  %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
+  %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } : tensor<8x?xf32> to tensor<8x4xf32>
+  func.return %3 : tensor<8x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dont_fold_pad_chains(
+//  CHECK-SAME:   %[[ARG0:.*]]: tensor<64x64xf32>,
+//  CHECK-SAME:   %[[SZ0:.*]]: index, %[[SZ1:.*]]: index, %[[PW0:.*]]: index, %[[PW1:.*]]: index
+func.func @dont_fold_pad_chains(%arg0: tensor<64x64xf32>,
+                                %sz0 : index, %sz1 : index,
+                                %pw0 : index, %pw1 : index) -> (tensor<8x4xf32>, tensor<4x64xf32>, tensor<8x4xf32>, tensor<6x4xf32>) {
+  //       CHECK:   %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+  //       CHECK:   %[[T1:.*]] = tensor.pad %[[T0]]
+  %pad_value = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %arg0[16, 0] [%sz0, 64] [1, 1] : tensor<64x64xf32> to tensor<?x64xf32>
+  %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } : tensor<?x64xf32> to tensor<8x64xf32>
+
+  // Don't fold if the padding values are different.
+  //       CHECK:   %[[T2:.*]] = tensor.extract_slice %[[T1]]
+  //  CHECK-SAME:                     [0, 4] [8, %[[SZ1]]]
+  //       CHECK:   %[[PAD0:.*]] = tensor.pad %[[T2]]
+  %different_value = arith.constant 1.0 : f32
+  %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
+  %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %different_value : f32
+    } : tensor<8x?xf32> to tensor<8x4xf32>
+
+  // Don't fold if the pad ops have common padding dimensions.
+  //       CHECK:   %[[T3:.*]] = tensor.extract_slice %[[T1]]
+  //  CHECK-SAME:                     [4, 0] [%[[SZ1]], 64]
+  //       CHECK:   %[[PAD1:.*]] = tensor.pad %[[T3]]
+  %4 = tensor.extract_slice %1[4, 0] [%sz1, 64] [1, 1] : tensor<8x64xf32> to tensor<?x64xf32>
+  %5 = tensor.pad %4 nofold low[0, 0] high[%pw1, 0] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } : tensor<?x64xf32> to tensor<4x64xf32>
+
+  // Don't fold if padded source tensor dimension is accessed at an offset.
+  //       CHECK:   %[[T4:.*]] = tensor.extract_slice %[[T1]]
+  //  CHECK-SAME:                     [%[[SZ0]], 4] [8, %[[SZ1]]
+  //       CHECK:   %[[PAD2:.*]] = tensor.pad %[[T4]]
+  %6 = tensor.extract_slice %1[%sz0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
+  %7 = tensor.pad %6 nofold low[0, 0] high[0, %pw1] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } : tensor<8x?xf32> to tensor<8x4xf32>
+
+  // Don't fold if a padded source tensor dimension is sliced.
+  //       CHECK:   %[[T5:.*]] = tensor.extract_slice %[[T1]]
+  //  CHECK-SAME:                     [0, 4] [6, %[[SZ1]]
+  //       CHECK:   %[[PAD3:.*]] = tensor.pad %[[T5]]
+  %8 = tensor.extract_slice %1[0, 4] [6, %sz1] [1, 1] : tensor<8x64xf32> to tensor<6x?xf32>
+  %9 = tensor.pad %8 nofold low[0, 0] high[0, %pw1] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } : tensor<6x?xf32> to tensor<6x4xf32>
+
+  //       CHECK:   return %[[PAD0]], %[[PAD1]], %[[PAD2]], %[[PAD3]]
+  func.return %3, %5, %7, %9 : tensor<8x4xf32>, tensor<4x64xf32>, tensor<8x4xf32>, tensor<6x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_collapse_shape_from_elements
 func @fold_collapse_shape_from_elements(%arg0: i32) -> tensor<i32> {
   // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<i32>


        

