[Mlir-commits] [mlir] 2a51e9f - [mlir] Support memref layout maps in vector transfer ops
Matthias Springer
llvmlistbot at llvm.org
Wed May 12 21:22:40 PDT 2021
Author: Matthias Springer
Date: 2021-05-13T13:22:21+09:00
New Revision: 2a51e9ff2e06d5d7096f826014916b4cc02269fc
URL: https://github.com/llvm/llvm-project/commit/2a51e9ff2e06d5d7096f826014916b4cc02269fc
DIFF: https://github.com/llvm/llvm-project/commit/2a51e9ff2e06d5d7096f826014916b4cc02269fc.diff
LOG: [mlir] Support memref layout maps in vector transfer ops
Differential Revision: https://reviews.llvm.org/D102042
Added:
Modified:
mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
index 7a016bd88547..7b976cc3c2a5 100644
--- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
@@ -1013,6 +1013,14 @@ struct Strategy1d<TransferWriteOp> {
}
};
+/// Return true if the last dimension of the MemRefType has unit stride.
+///
+/// Returns false when getStridesAndOffset fails for `type`. When it succeeds
+/// but reports no strides at all (presumably the rank-0 case; confirm against
+/// getStridesAndOffset), there is no last dimension that could have a
+/// non-unit stride, so the type is treated as unit-stride.
+static bool isLastMemrefDimUnitStride(MemRefType type) {
+ int64_t offset;
+ SmallVector<int64_t, 4> strides;
+ auto successStrides = getStridesAndOffset(type, strides, offset);
+ // Guard the back() access: calling it on an empty vector is invalid.
+ return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
+}
+
/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
/// necessary in cases where a 1D vector transfer op cannot be lowered into
/// vector load/stores due to non-unit strides or broadcasts:
@@ -1052,11 +1060,14 @@ struct TransferOp1dConversion : public OpRewritePattern<OpTy> {
PatternRewriter &rewriter) const override {
ScopedContext scope(rewriter, xferOp.getLoc());
auto map = xferOp.permutation_map();
+ auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
+ if (!memRefType)
+ return failure();
if (xferOp.getVectorType().getRank() != 1)
- return failure();
- if (map.isMinorIdentity()) // Handled by ConvertVectorToLLVM
- return failure();
+ return failure();
+ if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
+ return failure(); // Handled by ConvertVectorToLLVM
// Loop bounds, step, state...
auto vecType = xferOp.getVectorType();
More information about the Mlir-commits
mailing list