[Mlir-commits] [mlir] 59d3a9e - [mlir][vector] Separate high-D insert/extract strided slice rewrite

Lei Zhang llvmlistbot at llvm.org
Tue Apr 5 12:01:15 PDT 2022


Author: Lei Zhang
Date: 2022-04-05T15:00:50-04:00
New Revision: 59d3a9e0877b2b12fc98eea0f9bbbc93f3c7a094

URL: https://github.com/llvm/llvm-project/commit/59d3a9e0877b2b12fc98eea0f9bbbc93f3c7a094
DIFF: https://github.com/llvm/llvm-project/commit/59d3a9e0877b2b12fc98eea0f9bbbc93f3c7a094.diff

LOG: [mlir][vector] Separate high-D insert/extract strided slice rewrite

Right now `populateVectorInsertExtractStridedSliceTransforms` contains
two categories of patterns, one for decomposing high-D insert/extract
strided slices, the other for lowering them to shuffle ops.
They are at different levels---the former is a mid-level rewrite, while
the latter is a final lowering step. Split them to give users
more control over which patterns to pick.

This means breaking down the previous `VectorExtractStridedSliceOpRewritePattern`,
which was doing two things at once.

Also renamed those patterns to be clearer.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D123137

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 0a298ea9e91a5..0522ef58bc812 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -241,8 +241,8 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 
 /// Populate `patterns` with the following patterns.
 ///
-/// [VectorInsertStridedSliceOpDifferentRankRewritePattern]
-/// =======================================================
+/// [DecomposeDifferentRankInsertStridedSlice]
+/// ==========================================
 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
 /// have different ranks.
 ///
@@ -257,8 +257,19 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 ///   2. k-D -> (n-1)-D InsertStridedSlice op
 ///   3. InsertOp that is the reverse of 1.
 ///
-/// [VectorInsertStridedSliceOpSameRankRewritePattern]
-/// ==================================================
+/// [DecomposeNDExtractStridedSlice]
+/// ================================
+/// For n-D ExtractStridedSliceOp, rewrite it to ExtractOp/ExtractElementOp +
+/// lower rank ExtractStridedSliceOp + InsertOp/InsertElementOp.
+void populateVectorInsertExtractStridedSliceDecompositionPatterns(
+    RewritePatternSet &patterns);
+
+/// Populate `patterns` with the following patterns.
+///
+/// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns();
+///
+/// [ConvertSameRankInsertStridedSliceIntoShuffle]
+/// ==============================================
 /// RewritePattern for InsertStridedSliceOp where source and destination vectors
 /// have the same rank. For each outermost index in the slice:
 ///   begin    end             stride
@@ -268,12 +279,9 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 ///   3. the destination subvector is inserted back in the proper place
 ///   3. InsertOp that is the reverse of 1.
 ///
-/// [VectorExtractStridedSliceOpRewritePattern]
-/// ===========================================
-/// Progressive lowering of ExtractStridedSliceOp to either:
-///   1. single offset extract as a direct vector::ShuffleOp.
-///   2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp +
-///      InsertOp/InsertElementOp for the n-D case.
+/// [Convert1DExtractStridedSliceIntoShuffle]
+/// =========================================
+/// For ExtractStridedSliceOp with a single offset, lower it to a ShuffleOp.
 void populateVectorInsertExtractStridedSliceTransforms(
     RewritePatternSet &patterns);
 

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 2a384c3bf7853..a1e80e1fc3743 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -45,14 +45,14 @@ static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
 /// When ranks are different, InsertStridedSlice needs to extract a properly
 /// ranked vector from the destination vector into which to insert. This pattern
 /// only takes care of this extraction part and forwards the rest to
-/// [VectorInsertStridedSliceOpSameRankRewritePattern].
+/// [ConvertSameRankInsertStridedSliceIntoShuffle].
 ///
 /// For a k-D source and n-D destination vector (k < n), we emit:
 ///   1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
 ///      insert the k-D source.
 ///   2. k-D -> (n-1)-D InsertStridedSlice op
 ///   3. InsertOp that is the reverse of 1.
-class VectorInsertStridedSliceOpDifferentRankRewritePattern
+class DecomposeDifferentRankInsertStridedSlice
     : public OpRewritePattern<InsertStridedSliceOp> {
 public:
   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
@@ -102,7 +102,7 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern
 ///   2. InsertStridedSlice (k-1)-D into (n-1)-D
 ///   3. the destination subvector is inserted back in the proper place
 ///   3. InsertOp that is the reverse of 1.
-class VectorInsertStridedSliceOpSameRankRewritePattern
+class ConvertSameRankInsertStridedSliceIntoShuffle
     : public OpRewritePattern<InsertStridedSliceOp> {
 public:
   using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
@@ -193,11 +193,50 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
   }
 };
 
-/// Progressive lowering of ExtractStridedSliceOp to either:
-///   1. single offset extract as a direct vector::ShuffleOp.
-///   2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp +
-///      InsertOp/InsertElementOp for the n-D case.
-class VectorExtractStridedSliceOpRewritePattern
+/// RewritePattern for ExtractStridedSliceOp with a single offset (i.e. the
+/// slice is only along the leading dimension): lowers it to a ShuffleOp.
+class Convert1DExtractStridedSliceIntoShuffle
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto offsetAttrs = op.getOffsets().getValue();
+    assert(!offsetAttrs.empty() && "Unexpected empty offsets");
+    // Ops with more than one offset are handled by
+    // DecomposeNDExtractStridedSlice instead; bail out before any other work.
+    if (offsetAttrs.size() != 1)
+      return failure();
+
+    auto dstType = op.getType();
+    // Inlined into the assert so NDEBUG builds get no unused-variable warning.
+    assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
+
+    int64_t offset = offsetAttrs.front().cast<IntegerAttr>().getInt();
+    int64_t size =
+        op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
+    int64_t stride =
+        op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
+
+    // Shuffle mask: positions offset, offset + stride, ... (`size` entries)
+    // of the leading dimension; both shuffle operands are the source vector.
+    SmallVector<int64_t, 4> offsets;
+    offsets.reserve(size);
+    for (int64_t off = offset, e = offset + size * stride; off < e;
+         off += stride)
+      offsets.push_back(off);
+    rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
+                                           op.getVector(),
+                                           rewriter.getI64ArrayAttr(offsets));
+    return success();
+  }
+};
+
+/// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
+/// We rewrite it to ExtractOp/ExtractElementOp + lower rank
+/// ExtractStridedSliceOp + InsertOp/InsertElementOp.
+class DecomposeNDExtractStridedSlice
     : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
   using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
@@ -225,18 +264,10 @@ class VectorExtractStridedSliceOpRewritePattern
     auto elemType = dstType.getElementType();
     assert(elemType.isSignlessIntOrIndexOrFloat());
 
-    // Single offset can be more efficiently shuffled.
-    if (op.getOffsets().getValue().size() == 1) {
-      SmallVector<int64_t, 4> offsets;
-      offsets.reserve(size);
-      for (int64_t off = offset, e = offset + size * stride; off < e;
-           off += stride)
-        offsets.push_back(off);
-      rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
-                                             op.getVector(),
-                                             rewriter.getI64ArrayAttr(offsets));
-      return success();
-    }
+    // Single offset can be more efficiently shuffled. It's handled in
+    // Convert1DExtractStridedSliceIntoShuffle.
+    if (op.getOffsets().getValue().size() == 1)
+      return failure();
 
     // Extract/insert on a lower ranked extract strided slice op.
     Value zero = rewriter.create<arith::ConstantOp>(
@@ -256,11 +287,16 @@ class VectorExtractStridedSliceOpRewritePattern
   }
 };
 
+void mlir::vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<DecomposeDifferentRankInsertStridedSlice>(patterns.getContext());
+  patterns.add<DecomposeNDExtractStridedSlice>(patterns.getContext());
+}
+
 /// Populate `patterns` with decomposition and shuffle-lowering patterns.
 void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
     RewritePatternSet &patterns) {
-  patterns.add<VectorInsertStridedSliceOpDifferentRankRewritePattern,
-               VectorInsertStridedSliceOpSameRankRewritePattern,
-               VectorExtractStridedSliceOpRewritePattern>(
-      patterns.getContext());
+  populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns);
+  patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
+               Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext());
 }


        


More information about the Mlir-commits mailing list