[Mlir-commits] [mlir] eda2ebd - [mlir][Vector] NFC - Extract rewrites related to insert/extract strided slice in a separate file.
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Oct 22 03:03:51 PDT 2021
Author: Nicolas Vasilache
Date: 2021-10-22T10:03:33Z
New Revision: eda2ebd7807376829eb880c39623f364b438971f
URL: https://github.com/llvm/llvm-project/commit/eda2ebd7807376829eb880c39623f364b438971f
DIFF: https://github.com/llvm/llvm-project/commit/eda2ebd7807376829eb880c39623f364b438971f.diff
LOG: [mlir][Vector] NFC - Extract rewrites related to insert/extract strided slice in a separate file.
Differential Revision: https://reviews.llvm.org/D112301
Added:
mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp
Modified:
mlir/include/mlir/Dialect/Vector/VectorTransforms.h
mlir/include/mlir/Dialect/Vector/VectorUtils.h
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/CMakeLists.txt
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/lib/Dialect/Vector/VectorUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
new file mode 100644
index 0000000000000..13b310713f7b5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h
@@ -0,0 +1,58 @@
+//===- VectorRewritePatterns.h - Vector rewrite patterns --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
+#define MLIR_DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
+
+namespace mlir {
+class RewritePatternSet;
+
+namespace vector {
+
+/// Populate `patterns` with the following patterns.
+///
+/// [VectorInsertStridedSliceOpDifferentRankRewritePattern]
+/// =======================================================
+/// RewritePattern for InsertStridedSliceOp where source and destination vectors
+/// have different ranks.
+///
+/// When ranks are different, InsertStridedSlice needs to extract a properly
+/// ranked vector from the destination vector into which to insert. This pattern
+/// only takes care of this extraction part and forwards the rest to
+/// [VectorInsertStridedSliceOpSameRankRewritePattern].
+///
+/// For a k-D source and n-D destination vector (k < n), we emit:
+///   1. ExtractOp to extract the (unique) k-D destination subvector, indexed
+///      by the leading (n-k) offsets, into which to insert the k-D source.
+///   2. k-D -> k-D InsertStridedSlice op using the trailing k offsets.
+///   3. InsertOp that is the reverse of 1.
+///
+/// [VectorInsertStridedSliceOpSameRankRewritePattern]
+/// ==================================================
+/// RewritePattern for InsertStridedSliceOp where source and destination vectors
+/// have the same rank n. For each outermost index in the slice:
+///   begin    end             stride
+/// [offset : offset+size*stride : stride]
+///   1. ExtractOp one (n-1)-D source subvector and one (n-1)-D dest subvector.
+///   2. InsertStridedSlice (n-1)-D into (n-1)-D.
+///   3. InsertOp the updated (n-1)-D dest subvector back into the destination,
+///      i.e. the reverse of the dest extraction in 1.
+///
+/// [VectorExtractStridedSliceOpRewritePattern]
+/// ===========================================
+/// Progressive lowering of ExtractStridedSliceOp to either:
+///   1. single offset extract as a direct vector::ShuffleOp.
+///   2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp +
+///      InsertOp/InsertElementOp for the n-D case.
+void populateVectorInsertExtractStridedSliceTransforms(
+    RewritePatternSet &patterns);
+
+} // namespace vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index 59e6ac07bbca3..d26636c132ac4 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -24,13 +24,6 @@ namespace scf {
class IfOp;
} // namespace scf
-/// Collect a set of patterns to convert from the Vector dialect to itself.
-/// Should be merged with populateVectorToSCFLoweringPattern.
-void populateVectorToVectorConversionPatterns(
- MLIRContext *context, RewritePatternSet &patterns,
- ArrayRef<int64_t> coarseVectorShape = {},
- ArrayRef<int64_t> fineVectorShape = {});
-
namespace vector {
/// Options that control the vector unrolling.
diff --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
index 2df0229ec2f4a..788e0c8316e23 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_VECTORUTILS_H_
#define MLIR_DIALECT_VECTOR_VECTORUTILS_H_
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
@@ -184,6 +185,11 @@ bool checkSameValueRAW(vector::TransferWriteOp defWrite,
bool checkSameValueWAW(vector::TransferWriteOp write,
vector::TransferWriteOp priorWrite);
+/// Returns the elements of `arrayAttr` as a vector of int64_t, skipping the
+/// first `dropFront` and the last `dropBack` entries. Each element is read as
+/// an IntegerAttr and sign-extended to int64_t.
+/// Asserts that at least one element remains after dropping, i.e.
+/// `dropFront + dropBack < arrayAttr.size()`.
+SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
+ unsigned dropFront = 0,
+ unsigned dropBack = 0);
+
namespace matcher {
/// Matches vector.transfer_read, vector.transfer_write and ops that return a
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a6f25332d1331..77d2a46977172 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
@@ -52,17 +53,6 @@ static Value insertOne(ConversionPatternRewriter &rewriter,
rewriter.getI64ArrayAttr(pos));
}
-// Helper that picks the proper sequence for inserting.
-static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
- Value into, int64_t offset) {
- auto vectorType = into.getType().cast<VectorType>();
- if (vectorType.getRank() > 1)
- return rewriter.create<InsertOp>(loc, from, into, offset);
- return rewriter.create<vector::InsertElementOp>(
- loc, vectorType, from, into,
- rewriter.create<arith::ConstantIndexOp>(loc, offset));
-}
-
// Helper that picks the proper sequence for extracting.
static Value extractOne(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc,
@@ -79,32 +69,6 @@ static Value extractOne(ConversionPatternRewriter &rewriter,
rewriter.getI64ArrayAttr(pos));
}
-// Helper that picks the proper sequence for extracting.
-static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
- int64_t offset) {
- auto vectorType = vector.getType().cast<VectorType>();
- if (vectorType.getRank() > 1)
- return rewriter.create<ExtractOp>(loc, vector, offset);
- return rewriter.create<vector::ExtractElementOp>(
- loc, vectorType.getElementType(), vector,
- rewriter.create<arith::ConstantIndexOp>(loc, offset));
-}
-
-// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
-// TODO: Better support for attribute subtype forwarding + slicing.
-static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
- unsigned dropFront = 0,
- unsigned dropBack = 0) {
- assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
- auto range = arrayAttr.getAsRange<IntegerAttr>();
- SmallVector<int64_t, 4> res;
- res.reserve(arrayAttr.size() - dropFront - dropBack);
- for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
- it != eit; ++it)
- res.push_back((*it).getValue().getSExtValue());
- return res;
-}
-
// Helper that returns data layout alignment of a memref.
LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
MemRefType memrefType, unsigned &align) {
@@ -813,132 +777,6 @@ class VectorFMAOpNDRewritePattern : public OpRewritePattern<FMAOp> {
}
};
-// When ranks are different, InsertStridedSlice needs to extract a properly
-// ranked vector from the destination vector into which to insert. This pattern
-// only takes care of this part and forwards the rest of the conversion to
-// another pattern that converts InsertStridedSlice for operands of the same
-// rank.
-//
-// RewritePattern for InsertStridedSliceOp where source and destination vectors
-// have different ranks. In this case:
-// 1. the proper subvector is extracted from the destination vector
-// 2. a new InsertStridedSlice op is created to insert the source in the
-// destination subvector
-// 3. the destination subvector is inserted back in the proper place
-// 4. the op is replaced by the result of step 3.
-// The new InsertStridedSlice from step 2. will be picked up by a
-// `VectorInsertStridedSliceOpSameRankRewritePattern`.
-class VectorInsertStridedSliceOpDifferentRankRewritePattern
- : public OpRewritePattern<InsertStridedSliceOp> {
-public:
- using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(InsertStridedSliceOp op,
- PatternRewriter &rewriter) const override {
- auto srcType = op.getSourceVectorType();
- auto dstType = op.getDestVectorType();
-
- if (op.offsets().getValue().empty())
- return failure();
-
- auto loc = op.getLoc();
- int64_t rankDiff = dstType.getRank() - srcType.getRank();
- assert(rankDiff >= 0);
- if (rankDiff == 0)
- return failure();
-
- int64_t rankRest = dstType.getRank() - rankDiff;
- // Extract / insert the subvector of matching rank and InsertStridedSlice
- // on it.
- Value extracted =
- rewriter.create<ExtractOp>(loc, op.dest(),
- getI64SubArray(op.offsets(), /*dropFront=*/0,
- /*dropBack=*/rankRest));
- // A different pattern will kick in for InsertStridedSlice with matching
- // ranks.
- auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
- loc, op.source(), extracted,
- getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
- getI64SubArray(op.strides(), /*dropFront=*/0));
- rewriter.replaceOpWithNewOp<InsertOp>(
- op, stridedSliceInnerOp.getResult(), op.dest(),
- getI64SubArray(op.offsets(), /*dropFront=*/0,
- /*dropBack=*/rankRest));
- return success();
- }
-};
-
-// RewritePattern for InsertStridedSliceOp where source and destination vectors
-// have the same rank. In this case, we reduce
-// 1. the proper subvector is extracted from the destination vector
-// 2. a new InsertStridedSlice op is created to insert the source in the
-// destination subvector
-// 3. the destination subvector is inserted back in the proper place
-// 4. the op is replaced by the result of step 3.
-// The new InsertStridedSlice from step 2. will be picked up by a
-// `VectorInsertStridedSliceOpSameRankRewritePattern`.
-class VectorInsertStridedSliceOpSameRankRewritePattern
- : public OpRewritePattern<InsertStridedSliceOp> {
-public:
- using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
-
- void initialize() {
- // This pattern creates recursive InsertStridedSliceOp, but the recursion is
- // bounded as the rank is strictly decreasing.
- setHasBoundedRewriteRecursion();
- }
-
- LogicalResult matchAndRewrite(InsertStridedSliceOp op,
- PatternRewriter &rewriter) const override {
- auto srcType = op.getSourceVectorType();
- auto dstType = op.getDestVectorType();
-
- if (op.offsets().getValue().empty())
- return failure();
-
- int64_t rankDiff = dstType.getRank() - srcType.getRank();
- assert(rankDiff >= 0);
- if (rankDiff != 0)
- return failure();
-
- if (srcType == dstType) {
- rewriter.replaceOp(op, op.source());
- return success();
- }
-
- int64_t offset =
- op.offsets().getValue().front().cast<IntegerAttr>().getInt();
- int64_t size = srcType.getShape().front();
- int64_t stride =
- op.strides().getValue().front().cast<IntegerAttr>().getInt();
-
- auto loc = op.getLoc();
- Value res = op.dest();
- // For each slice of the source vector along the most major dimension.
- for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
- off += stride, ++idx) {
- // 1. extract the proper subvector (or element) from source
- Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
- if (extractedSource.getType().isa<VectorType>()) {
- // 2. If we have a vector, extract the proper subvector from destination
- // Otherwise we are at the element level and no need to recurse.
- Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
- // 3. Reduce the problem to lowering a new InsertStridedSlice op with
- // smaller rank.
- extractedSource = rewriter.create<InsertStridedSliceOp>(
- loc, extractedSource, extractedDest,
- getI64SubArray(op.offsets(), /* dropFront=*/1),
- getI64SubArray(op.strides(), /* dropFront=*/1));
- }
- // 4. Insert the extractedSource into the res vector.
- res = insertOne(rewriter, loc, extractedSource, res, off);
- }
-
- rewriter.replaceOp(op, res);
- return success();
- }
-};
-
/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static llvm::Optional<SmallVector<int64_t, 4>>
@@ -1189,67 +1027,6 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
}
};
-/// Progressive lowering of ExtractStridedSliceOp to either:
-/// 1. express single offset extract as a direct shuffle.
-/// 2. extract + lower rank strided_slice + insert for the n-D case.
-class VectorExtractStridedSliceOpConversion
- : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
- using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
-
- void initialize() {
- // This pattern creates recursive ExtractStridedSliceOp, but the recursion
- // is bounded as the rank is strictly decreasing.
- setHasBoundedRewriteRecursion();
- }
-
- LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
- PatternRewriter &rewriter) const override {
- auto dstType = op.getType();
-
- assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");
-
- int64_t offset =
- op.offsets().getValue().front().cast<IntegerAttr>().getInt();
- int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
- int64_t stride =
- op.strides().getValue().front().cast<IntegerAttr>().getInt();
-
- auto loc = op.getLoc();
- auto elemType = dstType.getElementType();
- assert(elemType.isSignlessIntOrIndexOrFloat());
-
- // Single offset can be more efficiently shuffled.
- if (op.offsets().getValue().size() == 1) {
- SmallVector<int64_t, 4> offsets;
- offsets.reserve(size);
- for (int64_t off = offset, e = offset + size * stride; off < e;
- off += stride)
- offsets.push_back(off);
- rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
- op.vector(),
- rewriter.getI64ArrayAttr(offsets));
- return success();
- }
-
- // Extract/insert on a lower ranked extract strided slice op.
- Value zero = rewriter.create<arith::ConstantOp>(
- loc, elemType, rewriter.getZeroAttr(elemType));
- Value res = rewriter.create<SplatOp>(loc, dstType, zero);
- for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
- off += stride, ++idx) {
- Value one = extractOne(rewriter, loc, op.vector(), off);
- Value extracted = rewriter.create<ExtractStridedSliceOp>(
- loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1),
- getI64SubArray(op.sizes(), /* dropFront=*/1),
- getI64SubArray(op.strides(), /* dropFront=*/1));
- res = insertOne(rewriter, loc, extracted, res, idx);
- }
- rewriter.replaceOp(op, res);
- return success();
- }
-};
-
} // namespace
/// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1257,10 +1034,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions) {
MLIRContext *ctx = converter.getDialect()->getContext();
- patterns.add<VectorFMAOpNDRewritePattern,
- VectorInsertStridedSliceOpDifferentRankRewritePattern,
- VectorInsertStridedSliceOpSameRankRewritePattern,
- VectorExtractStridedSliceOpConversion>(ctx);
+ patterns.add<VectorFMAOpNDRewritePattern>(ctx);
+ populateVectorInsertExtractStridedSliceTransforms(patterns);
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
patterns
.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index 594e8cd3bb7d3..f620a370c8359 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRVector
- VectorOps.cpp
+ VectorInsertExtractStridedSliceRewritePatterns.cpp
VectorMultiDimReductionTransforms.cpp
+ VectorOps.cpp
VectorTransferOpTransforms.cpp
VectorTransforms.cpp
VectorUtils.cpp
diff --git a/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp
new file mode 100644
index 0000000000000..1dc04027266a5
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -0,0 +1,236 @@
+//===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+/// Inserts `from` into `into` at position `offset` along the outermost
+/// dimension. Uses vector.insert when `into` has rank > 1; otherwise uses
+/// vector.insertelement with an index constant as the position.
+static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
+                       Value into, int64_t offset) {
+  auto intoType = into.getType().cast<VectorType>();
+  if (intoType.getRank() <= 1) {
+    Value position = rewriter.create<arith::ConstantIndexOp>(loc, offset);
+    return rewriter.create<vector::InsertElementOp>(loc, intoType, from, into,
+                                                    position);
+  }
+  return rewriter.create<InsertOp>(loc, from, into, offset);
+}
+
+/// Extracts the slice at position `offset` along the outermost dimension of
+/// `vector`. Uses vector.extract when `vector` has rank > 1; otherwise uses
+/// vector.extractelement (indexed by an index constant) to produce a scalar of
+/// the element type.
+static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
+                        int64_t offset) {
+  auto srcType = vector.getType().cast<VectorType>();
+  if (srcType.getRank() <= 1) {
+    Value position = rewriter.create<arith::ConstantIndexOp>(loc, offset);
+    return rewriter.create<vector::ExtractElementOp>(
+        loc, srcType.getElementType(), vector, position);
+  }
+  return rewriter.create<ExtractOp>(loc, vector, offset);
+}
+
+namespace {
+/// RewritePattern for InsertStridedSliceOp where source and destination vectors
+/// have different ranks.
+///
+/// When ranks are different, InsertStridedSlice needs to extract a properly
+/// ranked vector from the destination vector into which to insert. This pattern
+/// only takes care of this extraction part and forwards the rest to
+/// [VectorInsertStridedSliceOpSameRankRewritePattern].
+///
+/// For a k-D source and n-D destination vector (k < n), we emit:
+///   1. ExtractOp to extract the (unique) k-D destination subvector, indexed
+///      by the leading (n-k) offsets, into which to insert the k-D source.
+///   2. k-D -> k-D InsertStridedSlice op using the trailing k offsets.
+///   3. InsertOp that is the reverse of 1.
+class VectorInsertStridedSliceOpDifferentRankRewritePattern
+    : public OpRewritePattern<InsertStridedSliceOp> {
+public:
+  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto srcType = op.getSourceVectorType();
+    auto dstType = op.getDestVectorType();
+
+    if (op.offsets().getValue().empty())
+      return failure();
+
+    auto loc = op.getLoc();
+    int64_t rankDiff = dstType.getRank() - srcType.getRank();
+    assert(rankDiff >= 0);
+    if (rankDiff == 0)
+      return failure();
+
+    int64_t rankRest = dstType.getRank() - rankDiff;
+    // Position of the k-D destination subvector: the leading (n-k) offsets.
+    // Used both to extract it and to insert the updated subvector back.
+    SmallVector<int64_t, 4> position = getI64SubArray(
+        op.offsets(), /*dropFront=*/0, /*dropBack=*/rankRest);
+
+    // Extract / insert the subvector of matching rank and InsertStridedSlice
+    // on it.
+    Value extracted = rewriter.create<ExtractOp>(loc, op.dest(), position);
+
+    // A different pattern will kick in for InsertStridedSlice with matching
+    // ranks.
+    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
+        loc, op.source(), extracted,
+        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
+        getI64SubArray(op.strides(), /*dropFront=*/0));
+
+    rewriter.replaceOpWithNewOp<InsertOp>(op, stridedSliceInnerOp.getResult(),
+                                          op.dest(), position);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+/// RewritePattern for InsertStridedSliceOp where source and destination vectors
+/// have the same rank n. For each outermost index in the slice:
+///   begin    end             stride
+/// [offset : offset+size*stride : stride]
+///   1. ExtractOp one (n-1)-D source subvector and one (n-1)-D dest subvector.
+///   2. InsertStridedSlice (n-1)-D into (n-1)-D.
+///   3. InsertOp the updated (n-1)-D dest subvector back into the destination,
+///      i.e. the reverse of the dest extraction in 1.
+/// When n == 1 the extracted source values are scalars and step 2 is skipped.
+/// When source and destination types are identical, the op is replaced by its
+/// source.
+class VectorInsertStridedSliceOpSameRankRewritePattern
+    : public OpRewritePattern<InsertStridedSliceOp> {
+public:
+  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+
+  void initialize() {
+    // This pattern creates recursive InsertStridedSliceOp, but the recursion is
+    // bounded as the rank is strictly decreasing.
+    setHasBoundedRewriteRecursion();
+  }
+
+  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto srcType = op.getSourceVectorType();
+    auto dstType = op.getDestVectorType();
+
+    if (op.offsets().getValue().empty())
+      return failure();
+
+    int64_t rankDiff = dstType.getRank() - srcType.getRank();
+    assert(rankDiff >= 0);
+    if (rankDiff != 0)
+      return failure();
+
+    if (srcType == dstType) {
+      rewriter.replaceOp(op, op.source());
+      return success();
+    }
+
+    int64_t offset =
+        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
+    int64_t size = srcType.getShape().front();
+    int64_t stride =
+        op.strides().getValue().front().cast<IntegerAttr>().getInt();
+
+    auto loc = op.getLoc();
+    Value res = op.dest();
+    // For each slice of the source vector along the most major dimension.
+    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
+         off += stride, ++idx) {
+      // 1. Extract the proper subvector (or element) from the source.
+      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
+      if (extractedSource.getType().isa<VectorType>()) {
+        // 2. If we have a vector, extract the proper subvector from the
+        //    destination. Otherwise we are at the element level and there is
+        //    no need to recurse.
+        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
+        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
+        //    smaller rank.
+        extractedSource = rewriter.create<InsertStridedSliceOp>(
+            loc, extractedSource, extractedDest,
+            getI64SubArray(op.offsets(), /*dropFront=*/1),
+            getI64SubArray(op.strides(), /*dropFront=*/1));
+      }
+      // 4. Insert the extractedSource into the res vector.
+      res = insertOne(rewriter, loc, extractedSource, res, off);
+    }
+
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+/// Progressive lowering of ExtractStridedSliceOp to either:
+///   1. single offset extract as a direct vector::ShuffleOp.
+///   2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp +
+///      InsertOp/InsertElementOp for the n-D case.
+/// Ops with an empty `offsets` attribute are not matched.
+class VectorExtractStridedSliceOpRewritePattern
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+
+  void initialize() {
+    // This pattern creates recursive ExtractStridedSliceOp, but the recursion
+    // is bounded as the rank is strictly decreasing.
+    setHasBoundedRewriteRecursion();
+  }
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto dstType = op.getType();
+
+    // Bail out rather than calling `front()` on an empty array below (an
+    // assert only guards debug builds). This mirrors the InsertStridedSliceOp
+    // patterns, which also return failure() on empty offsets.
+    if (op.offsets().getValue().empty())
+      return failure();
+
+    int64_t offset =
+        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
+    int64_t size = op.sizes().getValue().front().cast<IntegerAttr>().getInt();
+    int64_t stride =
+        op.strides().getValue().front().cast<IntegerAttr>().getInt();
+
+    auto loc = op.getLoc();
+    auto elemType = dstType.getElementType();
+    assert(elemType.isSignlessIntOrIndexOrFloat());
+
+    // Single offset can be more efficiently shuffled.
+    if (op.offsets().getValue().size() == 1) {
+      SmallVector<int64_t, 4> offsets;
+      offsets.reserve(size);
+      for (int64_t off = offset, e = offset + size * stride; off < e;
+           off += stride)
+        offsets.push_back(off);
+      rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.vector(),
+                                             op.vector(),
+                                             rewriter.getI64ArrayAttr(offsets));
+      return success();
+    }
+
+    // Extract/insert on a lower ranked extract strided slice op. The result is
+    // seeded with a zero splat of the destination type.
+    Value zero = rewriter.create<arith::ConstantOp>(
+        loc, elemType, rewriter.getZeroAttr(elemType));
+    Value res = rewriter.create<SplatOp>(loc, dstType, zero);
+    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
+         off += stride, ++idx) {
+      Value one = extractOne(rewriter, loc, op.vector(), off);
+      Value extracted = rewriter.create<ExtractStridedSliceOp>(
+          loc, one, getI64SubArray(op.offsets(), /*dropFront=*/1),
+          getI64SubArray(op.sizes(), /*dropFront=*/1),
+          getI64SubArray(op.strides(), /*dropFront=*/1));
+      res = insertOne(rewriter, loc, extracted, res, idx);
+    }
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+};
+} // namespace
+
+/// Populate `patterns` with the InsertStridedSliceOp and ExtractStridedSliceOp
+/// rewrite patterns defined in this file. They progressively rewrite strided
+/// slice ops into lower-rank strided slice ops and simpler vector ops
+/// (extract/insert, extractelement/insertelement, shuffle); they do not target
+/// the LLVM dialect themselves.
+void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
+    RewritePatternSet &patterns) {
+  patterns.add<VectorInsertStridedSliceOpDifferentRankRewritePattern,
+               VectorInsertStridedSliceOpSameRankRewritePattern,
+               VectorExtractStridedSliceOpRewritePattern>(
+      patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 769a416278150..b0ba6a8f94fb5 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2204,20 +2204,6 @@ class StridedSliceConstantFolder final
}
};
-// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
-static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
- unsigned dropFront = 0,
- unsigned dropBack = 0) {
- assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
- auto range = arrayAttr.getAsRange<IntegerAttr>();
- SmallVector<int64_t, 4> res;
- res.reserve(arrayAttr.size() - dropFront - dropBack);
- for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
- it != eit; ++it)
- res.push_back((*it).getValue().getSExtValue());
- return res;
-}
-
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
// BroadcastOp(ExtractStrideSliceOp).
class StridedSliceBroadcast final
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 75ee23b5ffb3b..d98fa705dbf62 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1034,10 +1034,11 @@ class ShapeCastOp2DDownCastRewritePattern
};
/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
-/// vectors progressively on the way from targeting llvm.matrix intrinsics.
+/// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
-/// vector.strided_slice from 1-D + vector.insert into 2-D
+/// vector.extract_strided_slice from 1-D + vector.insert into 2-D
+/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
class ShapeCastOp2DUpCastRewritePattern
: public OpRewritePattern<vector::ShapeCastOp> {
public:
diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
index 937e303968a36..2659a313c464b 100644
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -362,3 +362,16 @@ bool mlir::checkSameValueWAW(vector::TransferWriteOp write,
priorWrite.getVectorType() == write.getVectorType() &&
priorWrite.permutation_map() == write.permutation_map();
}
+
+/// Returns the elements of `arrayAttr` with the first `dropFront` and the last
+/// `dropBack` entries removed, each read as an IntegerAttr and sign-extended to
+/// int64_t. At least one element must remain.
+SmallVector<int64_t, 4> mlir::getI64SubArray(ArrayAttr arrayAttr,
+                                             unsigned dropFront,
+                                             unsigned dropBack) {
+  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
+  // View over the retained elements; each is expected to be an IntegerAttr.
+  ArrayRef<Attribute> kept =
+      arrayAttr.getValue().drop_front(dropFront).drop_back(dropBack);
+  SmallVector<int64_t, 4> values;
+  values.reserve(kept.size());
+  for (Attribute attr : kept)
+    values.push_back(attr.cast<IntegerAttr>().getValue().getSExtValue());
+  return values;
+}
More information about the Mlir-commits
mailing list