[Mlir-commits] [mlir] 59d3a9e - [mlir][vector] Separate high-D insert/extract strided slice rewrite
Lei Zhang
llvmlistbot at llvm.org
Tue Apr 5 12:01:15 PDT 2022
Author: Lei Zhang
Date: 2022-04-05T15:00:50-04:00
New Revision: 59d3a9e0877b2b12fc98eea0f9bbbc93f3c7a094
URL: https://github.com/llvm/llvm-project/commit/59d3a9e0877b2b12fc98eea0f9bbbc93f3c7a094
DIFF: https://github.com/llvm/llvm-project/commit/59d3a9e0877b2b12fc98eea0f9bbbc93f3c7a094.diff
LOG: [mlir][vector] Separate high-D insert/extract strided slice rewrite
Right now `populateVectorInsertExtractStridedSliceTransforms` contains
two categories of patterns, one for decomposing high-D insert/extract
strided slices, the other for lowering them to shuffle ops.
They operate at different levels: the former is an intermediate
decomposition step, while the latter is part of the final lowering.
Split them to give users more control over which patterns to pick.
This means breaking down the previous `VectorExtractStridedSliceOpRewritePattern`,
which was doing two things at once.
Also rename those patterns to make them clearer.
Reviewed By: ThomasRaoux
Differential Revision: https://reviews.llvm.org/D123137
Added:
Modified:
mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 0a298ea9e91a5..0522ef58bc812 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -241,8 +241,8 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
/// Populate `patterns` with the following patterns.
///
-/// [VectorInsertStridedSliceOpDifferentRankRewritePattern]
-/// =======================================================
+/// [DecomposeDifferentRankInsertStridedSlice]
+/// ==========================================
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have different ranks.
///
@@ -257,8 +257,19 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
/// 2. k-D -> (n-1)-D InsertStridedSlice op
/// 3. InsertOp that is the reverse of 1.
///
-/// [VectorInsertStridedSliceOpSameRankRewritePattern]
-/// ==================================================
+/// [DecomposeNDExtractStridedSlice]
+/// ================================
+/// RewritePattern for ExtractStridedSliceOp where the source vector is n-D and
+/// the slice has more than one offset. Such an op is rewritten to
+/// ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp +
+/// InsertOp/InsertElementOp.
+void populateVectorInsertExtractStridedSliceDecompositionPatterns(
+ RewritePatternSet &patterns);
+
+/// Populate `patterns` with the following patterns.
+///
+/// All patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns,
+/// plus:
+///
+/// [ConvertSameRankInsertStridedSliceIntoShuffle]
+/// ==============================================
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have the same rank. For each outermost index in the slice:
/// begin end stride
@@ -268,12 +279,9 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
/// 3. the destination subvector is inserted back in the proper place
/// 3. InsertOp that is the reverse of 1.
///
-/// [VectorExtractStridedSliceOpRewritePattern]
-/// ===========================================
-/// Progressive lowering of ExtractStridedSliceOp to either:
-/// 1. single offset extract as a direct vector::ShuffleOp.
-/// 2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp +
-/// InsertOp/InsertElementOp for the n-D case.
+/// [Convert1DExtractStridedSliceIntoShuffle]
+/// =========================================
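+/// RewritePattern for ExtractStridedSliceOp with a single offset (e.g. 1-D
+/// source and destination vectors). Such an op is lowered to a ShuffleOp.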
void populateVectorInsertExtractStridedSliceTransforms(
RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 2a384c3bf7853..a1e80e1fc3743 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -45,14 +45,14 @@ static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
/// When ranks are different, InsertStridedSlice needs to extract a properly
/// ranked vector from the destination vector into which to insert. This pattern
/// only takes care of this extraction part and forwards the rest to
-/// [VectorInsertStridedSliceOpSameRankRewritePattern].
+/// [ConvertSameRankInsertStridedSliceIntoShuffle].
///
/// For a k-D source and n-D destination vector (k < n), we emit:
/// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
/// insert the k-D source.
/// 2. k-D -> (n-1)-D InsertStridedSlice op
/// 3. InsertOp that is the reverse of 1.
-class VectorInsertStridedSliceOpDifferentRankRewritePattern
+class DecomposeDifferentRankInsertStridedSlice
: public OpRewritePattern<InsertStridedSliceOp> {
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
@@ -102,7 +102,7 @@ class VectorInsertStridedSliceOpDifferentRankRewritePattern
/// 2. InsertStridedSlice (k-1)-D into (n-1)-D
/// 3. the destination subvector is inserted back in the proper place
/// 3. InsertOp that is the reverse of 1.
-class VectorInsertStridedSliceOpSameRankRewritePattern
+class ConvertSameRankInsertStridedSliceIntoShuffle
: public OpRewritePattern<InsertStridedSliceOp> {
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
@@ -193,11 +193,50 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
}
};
-/// Progressive lowering of ExtractStridedSliceOp to either:
-/// 1. single offset extract as a direct vector::ShuffleOp.
-/// 2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp +
-/// InsertOp/InsertElementOp for the n-D case.
-class VectorExtractStridedSliceOpRewritePattern
+/// RewritePattern for ExtractStridedSliceOp whose slice has a single
+/// offset/size/stride entry. This includes the case where the source and
+/// destination vectors are 1-D. Such a slice only selects along the leading
+/// dimension, so it is lowered to one ShuffleOp. The mask picks `size`
+/// positions starting at `offset` and stepping by `stride`. Multi-offset
+/// slices are left to DecomposeNDExtractStridedSlice.
+class Convert1DExtractStridedSliceIntoShuffle
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    // Match first. Non-single-offset ops (including an empty offsets array)
+    // must fail cleanly before the front() accesses and the assert below.
+    if (op.getOffsets().getValue().size() != 1)
+      return failure();
+
+    auto dstType = op.getType();
+    assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
+
+    int64_t offset =
+        op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
+    int64_t size =
+        op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
+    int64_t stride =
+        op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
+
+    // Shuffle mask: offset, offset + stride, ..., with one entry per result
+    // position along the leading dimension.
+    SmallVector<int64_t, 4> shuffleMask;
+    shuffleMask.reserve(size);
+    for (int64_t i = 0; i < size; ++i)
+      shuffleMask.push_back(offset + i * stride);
+    rewriter.replaceOpWithNewOp<ShuffleOp>(
+        op, dstType, op.getVector(), op.getVector(),
+        rewriter.getI64ArrayAttr(shuffleMask));
+    return success();
+  }
+};
+
+/// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
+/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
+/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
+class DecomposeNDExtractStridedSlice
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
@@ -225,18 +264,10 @@ class VectorExtractStridedSliceOpRewritePattern
auto elemType = dstType.getElementType();
assert(elemType.isSignlessIntOrIndexOrFloat());
- // Single offset can be more efficiently shuffled.
- if (op.getOffsets().getValue().size() == 1) {
- SmallVector<int64_t, 4> offsets;
- offsets.reserve(size);
- for (int64_t off = offset, e = offset + size * stride; off < e;
- off += stride)
- offsets.push_back(off);
- rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
- op.getVector(),
- rewriter.getI64ArrayAttr(offsets));
- return success();
- }
+ // Single offset can be more efficiently shuffled. It's handled in
+ // Convert1DExtractStridedSliceIntoShuffle.
+ if (op.getOffsets().getValue().size() == 1)
+ return failure();
// Extract/insert on a lower ranked extract strided slice op.
Value zero = rewriter.create<arith::ConstantOp>(
@@ -256,11 +287,16 @@ class VectorExtractStridedSliceOpRewritePattern
}
};
+/// Registers the patterns that decompose high-D insert/extract strided slices
+/// into lower-rank ones: DecomposeDifferentRankInsertStridedSlice, then
+/// DecomposeNDExtractStridedSlice.
+void mlir::vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
+    RewritePatternSet &patterns) {
+  auto *context = patterns.getContext();
+  patterns.add<DecomposeDifferentRankInsertStridedSlice>(context);
+  patterns.add<DecomposeNDExtractStridedSlice>(context);
+}
+
/// Populate the given list with patterns that lower vector insert/extract
/// strided slice ops within the vector dialect. These are the high-D
/// decomposition patterns plus the patterns that turn same-rank insertions and
/// single-offset extractions into ShuffleOps.
void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
    RewritePatternSet &patterns) {
-  patterns.add<VectorInsertStridedSliceOpDifferentRankRewritePattern,
-               VectorInsertStridedSliceOpSameRankRewritePattern,
-               VectorExtractStridedSliceOpRewritePattern>(
-      patterns.getContext());
+  populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns);
+  patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
+               Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext());
}
More information about the Mlir-commits
mailing list