[Mlir-commits] [mlir] e789efc - [mlir][linalg] Refactor PadTensorOpVectorizationPattern (NFC)

Matthias Springer llvmlistbot at llvm.org
Fri Jun 4 07:51:20 PDT 2021


Author: Matthias Springer
Date: 2021-06-04T23:45:08+09:00
New Revision: e789efc92a5aa6495a22bd3a93a03f640dc2f32a

URL: https://github.com/llvm/llvm-project/commit/e789efc92a5aa6495a22bd3a93a03f640dc2f32a
DIFF: https://github.com/llvm/llvm-project/commit/e789efc92a5aa6495a22bd3a93a03f640dc2f32a.diff

LOG: [mlir][linalg] Refactor PadTensorOpVectorizationPattern (NFC)

* Rename PadTensorOpVectorizationPattern to GenericPadTensorOpVectorizationPattern.
* Make GenericPadTensorOpVectorizationPattern a private pattern, to be instantiated via populatePadTensorOpVectorizationPatterns.
* Factor out parts of PadTensorOpVectorizationPattern into helper functions.

This commit prepares PadTensorOpVectorizationPattern for a series of subsequent commits that add more specialized PadTensorOp vectorization patterns.

Differential Revision: https://reviews.llvm.org/D103681

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 25f4b3627d6ef..9df8fbb2e4693 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -880,14 +880,14 @@ struct PadTensorOpTransformationPattern : public OpRewritePattern<PadTensorOp> {
                                 PatternRewriter &rewriter) const override;
 };
 
-/// PadTensorOp does not implement the LinalgStructuredOpInterface `LinalgOp`,
-/// it needs a specific pattern to vectorize.
-struct PadTensorOpVectorizationPattern : public OpRewritePattern<PadTensorOp> {
-  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(PadTensorOp padOp,
-                                PatternRewriter &rewriter) const override;
-};
+/// Populates `patterns` with patterns that vectorize linalg.pad_tensor.
+/// These patterns are meant to apply in a complementary fashion. Benefits
+/// are used to encode a certain ordering of pattern application. To avoid
+/// scattering magic constants throughout the code base, the patterns must be
+/// added with this function. `baseBenefit` can be used to offset the benefit
+/// of all PadTensorOp vectorization patterns by a certain value.
+void populatePadTensorOpVectorizationPatterns(
+    RewritePatternSet &patterns, PatternBenefit baseBenefit = 1);
 
 /// Match and rewrite for the pattern:
 /// ```

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 12a8d80c72fcc..b52059d535cf1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -650,66 +650,81 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
 // Misc. vectorization patterns.
 //----------------------------------------------------------------------------//
 
-/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
-/// TransferWriteOp. For now, this only applies when all low and high paddings
-/// are determined to be zero.
-LogicalResult PadTensorOpVectorizationPattern::matchAndRewrite(
-    linalg::PadTensorOp padOp, PatternRewriter &rewriter) const {
-  // Helper function to determine whether an OpFoldResult is not a zero Index.
-  auto isNotZeroIndex = [](OpFoldResult ofr) {
-    if (Attribute attr = ofr.dyn_cast<Attribute>())
-      return attr.cast<IntegerAttr>().getInt() != 0;
-    Value v = ofr.get<Value>();
-    if (auto constOp = v.getDefiningOp<ConstantOp>())
-      if (auto intAttr = constOp.getValue().dyn_cast<IntegerAttr>())
-        return intAttr.getValue().getSExtValue() != 0;
-    return true;
-  };
-
-  auto resultShapedType = padOp.result().getType().cast<ShapedType>();
-  // Bail on non-static shapes.
-  if (!resultShapedType.hasStaticShape())
-    return failure();
-
-  // If any pad_low is not a static 0, needs a mask. Bail for now.
-  if (llvm::any_of(padOp.getMixedLowPad(), isNotZeroIndex))
-    return failure();
-  VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
-  if (!vectorType)
-    return failure();
-
-  // Only support padding with a constant for now, i.e. either:
-  //   1. A BBarg from a different block.
-  //   2. A value defined outside of the current block.
-  Block &block = padOp.region().front();
+/// Given a block, return the Value that the block yields if that Value is
+/// constant. In this context, "constant" means "defined outside of the block".
+/// Should not be called on blocks that yield more than one value.
+///
+/// Values are considered constant in two cases:
+///  - A basic block argument from a different block.
+///  - A value defined outside of the block.
+///
+/// If the yielded value is not constant, an empty Value is returned.
+static Value getConstantYieldValueFromBlock(Block &block) {
   auto yieldOp = cast<YieldOp>(block.getTerminator());
   assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
-  Value padValue = yieldOp.values().front();
-  Operation *definingOp = padValue.getDefiningOp();
+  Value result = yieldOp.values().front();
+  Operation *definingOp = result.getDefiningOp();
+
+  // Check if yield value is defined inside the block.
   if (definingOp && definingOp->getBlock() == &block)
-    return failure();
-  if (!definingOp && padValue.cast<BlockArgument>().getOwner() == &block)
-    return failure();
+    return Value();
+  // Check if the yield value is a BB arg of the block.
+  if (!definingOp && result.cast<BlockArgument>().getOwner() == &block)
+    return Value();
 
-  // TODO: if any pad_high is not a static 0, needs a mask. For now, just bail.
-  if (llvm::any_of(padOp.getMixedHighPad(),
-                   [&](OpFoldResult ofr) { return isNotZeroIndex(ofr); }))
-    return failure();
+  return result;
+}
+
+/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
+/// TransferWriteOp. For now, this only applies when all low and high paddings
+/// are determined to be zero.
+struct GenericPadTensorOpVectorizationPattern
+    : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padOp,
+                                PatternRewriter &rewriter) const override {
+    /// Given an OpFoldResult, return true if its value is guaranteed to be a
+    /// zero integer.
+    auto isZeroInt = [&](OpFoldResult ofr) {
+      return isEqualConstantIntOrValue(ofr, rewriter.getIndexAttr(0)); };
+    // Low padding must be static 0.
+    if (!llvm::all_of(padOp.getMixedLowPad(), isZeroInt)) return failure();
+    // High padding must be static 0.
+    if (!llvm::all_of(padOp.getMixedHighPad(), isZeroInt)) return failure();
+    // Pad value must be a constant.
+    auto padValue = getConstantYieldValueFromBlock(padOp.region().front());
+    if (!padValue) return failure();
+
+    // Bail on non-static shapes.
+    auto resultShapedType = padOp.result().getType().cast<ShapedType>();
+    if (!resultShapedType.hasStaticShape())
+      return failure();
+    VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
+    if (!vectorType)
+      return failure();
 
-  // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
-  // TransferWriteOp@[0..0].
-  SmallVector<Value> indices(
-      resultShapedType.getRank(),
-      rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
-  Value read = rewriter.create<vector::TransferReadOp>(
-      padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
-  Value init =
-      rewriter.create<InitTensorOp>(padOp.getLoc(), resultShapedType.getShape(),
-                                    resultShapedType.getElementType());
-  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
-                                                       indices);
+    // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
+    // TransferWriteOp@[0..0].
+    SmallVector<Value> indices(
+        resultShapedType.getRank(),
+        rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
+    Value read = rewriter.create<vector::TransferReadOp>(
+        padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
+    Value init = rewriter.create<InitTensorOp>(
+        padOp.getLoc(), resultShapedType.getShape(),
+        resultShapedType.getElementType());
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
+                                                         indices);
 
-  return success();
+    return success();
+  }
+};
+
+void mlir::linalg::populatePadTensorOpVectorizationPatterns(
+    RewritePatternSet &patterns, PatternBenefit baseBenefit) {
+  patterns.add<GenericPadTensorOpVectorizationPattern>(
+      patterns.getContext(), baseBenefit);
 }
 
 // TODO: cleanup all the convolution vectorization patterns.

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 181e1d93b250a..16eb79ffe4fb6 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -508,7 +508,7 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) {
       funcOp.getContext(),
       LinalgTransformationFilter()
           .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
-  patterns.add<PadTensorOpVectorizationPattern>(funcOp.getContext());
+  populatePadTensorOpVectorizationPatterns(patterns);
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 


        


More information about the Mlir-commits mailing list