[Mlir-commits] [mlir] cb1a423 - [mlir][vector] Move splitting transfer ops into a separate entry point

Lei Zhang llvmlistbot at llvm.org
Tue Feb 16 07:04:49 PST 2021


Author: Lei Zhang
Date: 2021-02-16T10:04:34-05:00
New Revision: cb1a42359bff2ba49d072df88ad3ffb4c66c16d8

URL: https://github.com/llvm/llvm-project/commit/cb1a42359bff2ba49d072df88ad3ffb4c66c16d8
DIFF: https://github.com/llvm/llvm-project/commit/cb1a42359bff2ba49d072df88ad3ffb4c66c16d8.diff

LOG: [mlir][vector] Move splitting transfer ops into a separate entry point

These patterns unroll transfer read/write ops if the vector consumers/
producers are extract/insert slices ops. Transfer ops can map to hardware
load/store functionalities, where the vector size matters for bandwidth
considerations. So these patterns should be collected separately, instead
of being generic canonicalization patterns.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D96782

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.h
    mlir/lib/Dialect/Vector/VectorTransforms.cpp
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 3650470bc7be..ee7ed62dcf01 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -45,6 +45,18 @@ void populateVectorToVectorCanonicalizationPatterns(
 void populateVectorToVectorTransformationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context);
 
+/// Collect a set of patterns to split transfer read/write ops.
+///
+/// These patterns unroll transfer read/write ops if the vector consumers/
+/// producers are extract/insert slices ops. Transfer ops can map to hardware
+/// load/store functionalities, where the vector size matters for bandwidth
+/// considerations. So these patterns should be collected separately, instead
+/// of being generic canonicalization patterns. Also one can let the
+/// `ignoreFilter` return true to fail matching for fine-grained control.
+void populateSplitVectorTransferPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context,
+    std::function<bool(Operation *)> ignoreFilter = nullptr);
+
 /// Collect a set of leading one dimension removal patterns.
 ///
 /// These patterns insert vector.shape_cast to remove leading one dimensions

diff  --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index dd3ea8ead746..0a6c88d4d99b 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -705,22 +705,33 @@ mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op,
 
 namespace {
 
-// Splits vector TransferReadOp into smaller TransferReadOps based on slicing
-// scheme of its unique ExtractSlicesOp user.
-struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+// Splits a TransferReadOp into smaller TransferReadOps based on slicing
+// scheme of its unique ExtractSlicesOp users.
+class SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
+public:
+  SplitTransferReadOp(MLIRContext *context,
+                      std::function<bool(Operation *)> ignoreFilter = nullptr,
+                      PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), ignoreFilter(ignoreFilter) {}
 
-  LogicalResult matchAndRewrite(vector::TransferReadOp xferReadOp,
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
-    // TODO: Support splitting TransferReadOp with non-identity
-    // permutation maps. Repurpose code from MaterializeVectors transformation.
-    if (!isIdentitySuffix(xferReadOp.permutation_map()))
+    if (ignoreFilter && ignoreFilter(readOp))
+      return failure();
+
+    // TODO: Support splitting TransferReadOp with non-identity permutation
+    // maps. Repurpose code from MaterializeVectors transformation.
+    if (!isIdentitySuffix(readOp.permutation_map()))
+      return failure();
+
+    // Return unless there is only one user, and it is an ExtractSlicesOp.
+    Value readResult = readOp.getResult();
+    if (!readResult.hasOneUse())
       return failure();
-    // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
-    Value xferReadResult = xferReadOp.getResult();
+
     auto extractSlicesOp =
-        dyn_cast<vector::ExtractSlicesOp>(*xferReadResult.getUsers().begin());
-    if (!xferReadResult.hasOneUse() || !extractSlicesOp)
+        dyn_cast<vector::ExtractSlicesOp>(readResult.use_begin()->getOwner());
+    if (!extractSlicesOp)
       return failure();
 
     // Get 'sizes' and 'strides' parameters from ExtractSlicesOp user.
@@ -730,37 +741,48 @@ struct SplitTransferReadOp : public OpRewritePattern<vector::TransferReadOp> {
     extractSlicesOp.getStrides(strides);
     assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
 
-    Value newVec = unrollTransferReadOp(xferReadOp, sizes, rewriter);
+    Value newVec = unrollTransferReadOp(readOp, sizes, rewriter);
     if (!newVec)
       return failure();
-    rewriter.replaceOp(xferReadOp, newVec);
+    rewriter.replaceOp(readOp, newVec);
     return success();
   }
+
+private:
+  std::function<bool(Operation *)> ignoreFilter;
 };
 
-// Splits vector TransferWriteOp into smaller TransferWriteOps for each source.
-struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+// Splits a TransferWriteOp into smaller TransferWriteOps for each source.
+class SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
+public:
+  SplitTransferWriteOp(MLIRContext *context,
+                       std::function<bool(Operation *)> ignoreFilter = nullptr,
+                       PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), ignoreFilter(ignoreFilter) {}
 
-  LogicalResult matchAndRewrite(vector::TransferWriteOp xferWriteOp,
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                 PatternRewriter &rewriter) const override {
-    // TODO: Support splitting TransferWriteOp with non-identity
-    // permutation maps. Repurpose code from MaterializeVectors transformation.
-    if (!isIdentitySuffix(xferWriteOp.permutation_map()))
+    if (ignoreFilter && ignoreFilter(writeOp))
+      return failure();
+
+    // TODO: Support splitting TransferWriteOp with non-identity permutation
+    // maps. Repurpose code from MaterializeVectors transformation.
+    if (!isIdentitySuffix(writeOp.permutation_map()))
       return failure();
-    // Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'.
-    auto *vectorDefOp = xferWriteOp.vector().getDefiningOp();
-    auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(vectorDefOp);
+
+    // Fail to match unless this is writing a vector resulting from an
+    // InsertSlicesOp.
+    auto insertSlicesOp =
+        writeOp.vector().getDefiningOp<vector::InsertSlicesOp>();
     if (!insertSlicesOp)
       return failure();
 
-    // Get TupleOp operand of 'insertSlicesOp'.
-    auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
-        insertSlicesOp.vectors().getDefiningOp());
+    // Get the TupleOp operand of the InsertSlicesOp.
+    auto tupleOp = insertSlicesOp.vectors().getDefiningOp<vector::TupleOp>();
     if (!tupleOp)
       return failure();
 
-    // Get 'sizes' and 'strides' parameters from InsertSlicesOp user.
+    // Get 'sizes' and 'strides' parameters from the InsertSlicesOp user.
     auto sourceTupleType = insertSlicesOp.getSourceTupleType();
     auto resultVectorType = insertSlicesOp.getResultVectorType();
     SmallVector<int64_t, 4> sizes;
@@ -768,21 +790,20 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
     SmallVector<int64_t, 4> strides;
     insertSlicesOp.getStrides(strides);
 
-    Location loc = xferWriteOp.getLoc();
+    Location loc = writeOp.getLoc();
     auto shapedElementType =
-        xferWriteOp.source().getType().cast<ShapedType>().getElementType();
-    SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
-                                  xferWriteOp.indices().end());
+        writeOp.source().getType().cast<ShapedType>().getElementType();
+    auto indices = llvm::to_vector<4>(writeOp.indices());
     Value resultTensor;
     auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
       // Create split TransferWriteOp for source vector 'tupleOp.operand[i]'.
-      // `masked` attribute propagates conservatively: if the coarse op didn't
+      // 'masked' attribute propagates conservatively: if the coarse op didn't
       // need masking, the fine op doesn't either.
       Operation *write = rewriter.create<vector::TransferWriteOp>(
           loc, tupleOp.getOperand(index),
-          resultTensor ? resultTensor : xferWriteOp.source(), sliceIndices,
-          xferWriteOp.permutation_map(),
-          xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr());
+          resultTensor ? resultTensor : writeOp.source(), sliceIndices,
+          writeOp.permutation_map(),
+          writeOp.masked() ? *writeOp.masked() : ArrayAttr());
       if (!write->getResults().empty())
         resultTensor = write->getResult(0);
     };
@@ -790,13 +811,15 @@ struct SplitTransferWriteOp : public OpRewritePattern<vector::TransferWriteOp> {
                              sourceTupleType, sizes, strides, indices, rewriter,
                              createSlice);
 
-    // Erase old 'xferWriteOp'.
     if (resultTensor)
-      rewriter.replaceOp(xferWriteOp, ArrayRef<Value>(resultTensor));
+      rewriter.replaceOp(writeOp, ArrayRef<Value>(resultTensor));
     else
-      rewriter.eraseOp(xferWriteOp);
+      rewriter.eraseOp(writeOp);
     return success();
   }
+
+private:
+  std::function<bool(Operation *)> ignoreFilter;
 };
 
 /// Decomposes ShapeCastOp on tuple-of-vectors to multiple ShapeCastOps, each
@@ -3105,15 +3128,16 @@ struct BubbleUpBitCastForStridedSliceInsert
 // TODO: Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
-  // clang-format off
-  patterns.insert<ShapeCastOpDecomposer,
-                  ShapeCastOpFolder,
-                  SplitTransferReadOp,
-                  SplitTransferWriteOp,
-                  TupleGetFolderOp,
-                  TransferReadExtractPattern,
-                  TransferWriteInsertPattern>(context);
-  // clang-format on
+  patterns.insert<ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp,
+                  TransferReadExtractPattern, TransferWriteInsertPattern>(
+      context);
+}
+
+void mlir::vector::populateSplitVectorTransferPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context,
+    std::function<bool(Operation *)> ignoreFilter) {
+  patterns.insert<SplitTransferReadOp, SplitTransferWriteOp>(context,
+                                                             ignoreFilter);
 }
 
 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 8ec970f68b23..17da2d42fa32 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -47,6 +47,7 @@ struct TestVectorToVectorConversion
     populateVectorToVectorTransformationPatterns(patterns, ctx);
     populateBubbleVectorBitCastOpPatterns(patterns, ctx);
     populateCastAwayVectorLeadingOneDimPatterns(patterns, ctx);
+    populateSplitVectorTransferPatterns(patterns, ctx);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 


        


More information about the Mlir-commits mailing list