[Mlir-commits] [mlir] 973dbe2 - [mlir][tensor] Add pattern to fold ExtractSliceOp, PadOp chains.
llvmlistbot at llvm.org
Mon Apr 11 07:29:56 PDT 2022
Author: gysit
Date: 2022-04-11T14:28:59Z
New Revision: 973dbe20f681ca885edd0f5e63cde62dbdb6c186
URL: https://github.com/llvm/llvm-project/commit/973dbe20f681ca885edd0f5e63cde62dbdb6c186
DIFF: https://github.com/llvm/llvm-project/commit/973dbe20f681ca885edd0f5e63cde62dbdb6c186.diff
LOG: [mlir][tensor] Add pattern to fold ExtractSliceOp, PadOp chains.
The pattern folds chains of tensor::ExtractSliceOp, tensor::PadOp pairs if they pad different dimensions. Repeated tiling and padding of the tiled dimensions may introduce such chains. The canonicalization folds these chains into a single tensor::ExtractSliceOp, tensor::PadOp pair that pads all dimensions at once, which simplifies vectorization and bufferization.
Example:
```mlir
%0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
: tensor<64x64xf32> to tensor<?x64xf32>
%1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
} : tensor<?x64xf32> to tensor<8x64xf32>
%2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
: tensor<8x64xf32> to tensor<8x?xf32>
%res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
} : tensor<8x?xf32> to tensor<8x4xf32>
```
folds into:
```mlir
%0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
: tensor<64x64xf32> to tensor<?x?xf32>
%res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
} : tensor<?x?xf32> to tensor<8x4xf32>
```
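The fold is registered as a canonicalization pattern on tensor.pad, so it applies wherever canonicalization runs (for example via `mlir-opt -canonicalize`). Below is a minimal sketch, for illustration only, of driving these patterns programmatically; the `foldPadChains` helper and its `module` argument are assumptions of the example, not part of this commit:

```cpp
// Sketch: collect the tensor.pad canonicalization patterns, which now include
// FoldOrthogonalPaddings, and apply them greedily until a fixed point.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult foldPadChains(mlir::ModuleOp module) {
  mlir::MLIRContext *ctx = module.getContext();
  mlir::RewritePatternSet patterns(ctx);
  // Adds FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
  // and the new FoldOrthogonalPaddings.
  mlir::tensor::PadOp::getCanonicalizationPatterns(patterns, ctx);
  return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
}
```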
Reviewed By: nicolasvasilache, hanchung
Differential Revision: https://reviews.llvm.org/D122722
Added:
Modified:
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 30c74837efa79..660c80865abda 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -678,7 +678,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
Tensor_Op<mnemonic, !listconcat(traits, [NoSideEffect])>,
Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>,
Results<(outs AnyTensor:$result)> {
-
+
code commonExtraClassDeclaration = [{
static StringRef getReassociationAttrName() { return "reassociation"; }
SmallVector<AffineMap, 4> getReassociationMaps();
@@ -982,6 +982,8 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
return getConstantIntValue(ofr) == static_cast<int64_t>(0);
});
}
+ /// Return the dimensions with a non-zero low or high padding.
+ llvm::SmallBitVector getPaddedDims();
}];
let builders = [
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 5b52a3fdd24ce..f1181e63ec8f2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1858,6 +1858,18 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
result.addAttributes(attrs);
}
+llvm::SmallBitVector PadOp::getPaddedDims() {
+ llvm::SmallBitVector paddedDims(getSourceType().getRank());
+ auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
+ for (const auto &en : enumerate(paddingWidths))
+ if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
+ paddedDims.set(en.index());
+ };
+ extractPaddedDims(getMixedLowPad());
+ extractPaddedDims(getMixedHighPad());
+ return paddedDims;
+}
+
namespace {
// Folds tensor.pad when padding is static zeros and the attribute
// doesn't request otherwise.
@@ -1940,13 +1952,169 @@ struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
return success();
}
};
+
+/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
+/// different dimensions. The pattern applies if the following preconditions
+/// hold:
+/// 1) the tensor::ExtractSliceOps are not rank-reducing,
+/// 2) the tensor::ExtractSliceOps have only unit-strides,
+/// 3) the tensor::PadOps perform only high-padding,
+/// 4) the tensor::PadOps have the same constant padding value,
+/// 5) the tensor::PadOps do not have common padding dimensions,
+/// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
+/// zero-offset for every dimension.
+/// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the
+/// padded source dimensions.
+///
+/// Example:
+///
+/// ```mlir
+/// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
+/// : tensor<64x64xf32> to tensor<?x64xf32>
+/// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
+/// } : tensor<?x64xf32> to tensor<8x64xf32>
+/// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
+/// : tensor<8x64xf32> to tensor<8x?xf32>
+/// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
+/// } : tensor<8x?xf32> to tensor<8x4xf32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
+/// : tensor<64x64xf32> to tensor<?x?xf32>
+/// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
+/// } : tensor<?x?xf32> to tensor<8x4xf32>
+/// ```
+struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
+ using OpRewritePattern<PadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(PadOp padOp,
+ PatternRewriter &rewriter) const override {
+ auto innerSliceOp = padOp.source().getDefiningOp<ExtractSliceOp>();
+ if (!innerSliceOp)
+ return failure();
+ auto outerPadOp = innerSliceOp.source().getDefiningOp<PadOp>();
+ if (!outerPadOp || outerPadOp.nofold())
+ return failure();
+ auto outerSliceOp = outerPadOp.source().getDefiningOp<ExtractSliceOp>();
+ if (!outerSliceOp)
+ return failure();
+
+ // 1) Fail if the chain is rank-reducing.
+ int64_t rank = padOp.getSourceType().getRank();
+ if (outerSliceOp.getSourceType().getRank() != rank) {
+ return rewriter.notifyMatchFailure(padOp,
+ "cannot fold rank-reducing chain");
+ }
+
+ // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
+ if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
+ return rewriter.notifyMatchFailure(
+ padOp, "cannot fold non-unit stride ExtractSliceOps");
+ }
+
+ // 3) Fail if the tensor::PadOps have non-zero low padding.
+ if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
+ return rewriter.notifyMatchFailure(padOp,
+ "cannot fold PadOps with low padding");
+ }
+
+ // 4) Fail if the tensor::PadOps padding values do not match.
+ Attribute innerAttr, outerAttr;
+ Value innerValue = padOp.getConstantPaddingValue();
+ Value outerValue = outerPadOp.getConstantPaddingValue();
+ if (!innerValue || !outerValue ||
+ !matchPattern(innerValue, m_Constant(&innerAttr)) ||
+ !matchPattern(outerValue, m_Constant(&outerAttr)) ||
+ innerAttr != outerAttr) {
+ return rewriter.notifyMatchFailure(
+ padOp, "cannot fold PadOps with different padding values");
+ }
+
+ // 5) Fail if a dimension is padded by both tensor::PadOps.
+ llvm::SmallBitVector innerDims = padOp.getPaddedDims();
+ llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
+ if (innerDims.anyCommon(outerDims)) {
+ return rewriter.notifyMatchFailure(
+ padOp, "cannot fold PadOps with common padding dimensions");
+ }
+
+ // 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
+ // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
+ // for every dimension, and use the offset of the other pair. Fail if no
+ // zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
+ // exists.
+ SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
+ for (auto &en : enumerate(newOffsets)) {
+ OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
+ OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
+ if (!innerDims.test(en.index()) &&
+ (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
+ en.value() = outerOffset;
+ continue;
+ }
+ if (!outerDims.test(en.index()) &&
+ (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
+ en.value() = innerOffset;
+ continue;
+ }
+ return rewriter.notifyMatchFailure(
+ padOp, "cannot find zero-offset and zero-padding pair");
+ }
+
+ // 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size of
+ // the outer tensor::ExtractSliceOp for the dimensions padded by the outer
+ // tensor::PadOp and fail if the size of the inner tensor::ExtractSliceOp
+ // does not match the size of the padded dimension. Otherwise, take the size
+ // of the inner tensor::ExtractSliceOp.
+ SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
+ for (auto &en : enumerate(newSizes)) {
+ if (!outerDims.test(en.index()))
+ continue;
+ OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
+ int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
+ assert(!ShapedType::isDynamic(sourceSize) &&
+ "expected padded dimension to have a static size");
+ if (getConstantIntValue(sliceSize) != sourceSize) {
+ return rewriter.notifyMatchFailure(
+ padOp, "cannot fold since the inner ExtractSliceOp size does not "
+ "match the size of the outer padding");
+ }
+ en.value() = outerSliceOp.getMixedSizes()[en.index()];
+ }
+
+ // Combine the high paddings of the two tensor::PadOps.
+ SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
+ for (auto &en : enumerate(newHighPad)) {
+ if (innerDims.test(en.index()))
+ newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
+ if (outerDims.test(en.index()))
+ newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
+ }
+
+ // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs the
+ // two paddings in one step.
+ auto newSliceOp = rewriter.create<ExtractSliceOp>(
+ padOp.getLoc(), outerSliceOp.source(), newOffsets, newSizes,
+ innerSliceOp.getMixedStrides());
+ auto newPadOp = rewriter.create<PadOp>(
+ padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
+ padOp.getMixedLowPad(), newHighPad, padOp.nofold());
+ rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
+ newPadOp.getRegion().begin());
+ rewriter.replaceOp(padOp, newPadOp.getResult());
+ return success();
+ }
+};
+
} // namespace
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results
- .add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast>(
- context);
+ results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
+ FoldOrthogonalPaddings>(context);
}
/// Return the padding value of the PadOp if it is constant. In this context,
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 493f278763bb7..fef9617e90195 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1252,6 +1252,95 @@ func @pad_nofold_static_zero(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tenso
// -----
+// CHECK-LABEL: func @fold_orthogonal_pad_chains(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<64x64xf32>,
+// CHECK-SAME: %[[SZ0:.*]]: index, %[[SZ1:.*]]: index, %[[PW0:.*]]: index, %[[PW1:.*]]: index
+func.func @fold_orthogonal_pad_chains(%arg0: tensor<64x64xf32>,
+ %sz0 : index, %sz1 : index,
+ %pw0 : index, %pw1 : index) -> tensor<8x4xf32> {
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+ // CHECK-SAME: [16, 4] [%[[SZ0]], %[[SZ1]]]
+ // CHECK: %[[PAD:.*]] = tensor.pad %[[T0]] nofold
+ // CHECK-SAME: high[%[[PW0]], %[[PW1]]]
+ // CHECK: return %[[PAD]]
+ %pad_value = arith.constant 0.0 : f32
+ %0 = tensor.extract_slice %arg0[16, 0] [%sz0, 64] [1, 1] : tensor<64x64xf32> to tensor<?x64xf32>
+ %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] {
+ ^bb0(%arg1: index, %arg2: index):
+ tensor.yield %pad_value : f32
+ } : tensor<?x64xf32> to tensor<8x64xf32>
+ %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
+ %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] {
+ ^bb0(%arg1: index, %arg2: index):
+ tensor.yield %pad_value : f32
+ } : tensor<8x?xf32> to tensor<8x4xf32>
+ func.return %3 : tensor<8x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dont_fold_pad_chains(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<64x64xf32>,
+// CHECK-SAME: %[[SZ0:.*]]: index, %[[SZ1:.*]]: index, %[[PW0:.*]]: index, %[[PW1:.*]]: index
+func.func @dont_fold_pad_chains(%arg0: tensor<64x64xf32>,
+ %sz0 : index, %sz1 : index,
+ %pw0 : index, %pw1 : index) -> (tensor<8x4xf32>, tensor<4x64xf32>, tensor<8x4xf32>, tensor<6x4xf32>) {
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+ // CHECK: %[[T1:.*]] = tensor.pad %[[T0]]
+ %pad_value = arith.constant 0.0 : f32
+ %0 = tensor.extract_slice %arg0[16, 0] [%sz0, 64] [1, 1] : tensor<64x64xf32> to tensor<?x64xf32>
+ %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] {
+ ^bb0(%arg1: index, %arg2: index):
+ tensor.yield %pad_value : f32
+ } : tensor<?x64xf32> to tensor<8x64xf32>
+
+ // Don't fold if the padding values are different.
+ // CHECK: %[[T2:.*]] = tensor.extract_slice %[[T1]]
+ // CHECK-SAME: [0, 4] [8, %[[SZ1]]]
+ // CHECK: %[[PAD0:.*]] = tensor.pad %[[T2]]
+ %different_value = arith.constant 1.0 : f32
+ %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
+ %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] {
+ ^bb0(%arg1: index, %arg2: index):
+ tensor.yield %different_value : f32
+ } : tensor<8x?xf32> to tensor<8x4xf32>
+
+ // Don't fold if the pad ops have common padding dimensions.
+ // CHECK: %[[T3:.*]] = tensor.extract_slice %[[T1]]
+ // CHECK-SAME: [4, 0] [%[[SZ1]], 64]
+ // CHECK: %[[PAD1:.*]] = tensor.pad %[[T3]]
+ %4 = tensor.extract_slice %1[4, 0] [%sz1, 64] [1, 1] : tensor<8x64xf32> to tensor<?x64xf32>
+ %5 = tensor.pad %4 nofold low[0, 0] high[%pw1, 0] {
+ ^bb0(%arg1: index, %arg2: index):
+ tensor.yield %pad_value : f32
+ } : tensor<?x64xf32> to tensor<4x64xf32>
+
+ // Don't fold if padded source tensor dimension is accessed at an offset.
+ // CHECK: %[[T4:.*]] = tensor.extract_slice %[[T1]]
+ // CHECK-SAME: [%[[SZ0]], 4] [8, %[[SZ1]]
+ // CHECK: %[[PAD2:.*]] = tensor.pad %[[T4]]
+ %6 = tensor.extract_slice %1[%sz0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
+ %7 = tensor.pad %6 nofold low[0, 0] high[0, %pw1] {
+ ^bb0(%arg1: index, %arg2: index):
+ tensor.yield %pad_value : f32
+ } : tensor<8x?xf32> to tensor<8x4xf32>
+
+ // Don't fold if a padded source tensor dimension is sliced.
+ // CHECK: %[[T5:.*]] = tensor.extract_slice %[[T1]]
+ // CHECK-SAME: [0, 4] [6, %[[SZ1]]
+ // CHECK: %[[PAD3:.*]] = tensor.pad %[[T5]]
+ %8 = tensor.extract_slice %1[0, 4] [6, %sz1] [1, 1] : tensor<8x64xf32> to tensor<6x?xf32>
+ %9 = tensor.pad %8 nofold low[0, 0] high[0, %pw1] {
+ ^bb0(%arg1: index, %arg2: index):
+ tensor.yield %pad_value : f32
+ } : tensor<6x?xf32> to tensor<6x4xf32>
+
+ // CHECK: return %[[PAD0]], %[[PAD1]], %[[PAD2]], %[[PAD3]]
+ func.return %3, %5, %7, %9 : tensor<8x4xf32>, tensor<4x64xf32>, tensor<8x4xf32>, tensor<6x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_collapse_shape_from_elements
func @fold_collapse_shape_from_elements(%arg0: i32) -> tensor<i32> {
// CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<i32>