[Mlir-commits] [mlir] 2a51e9f - [mlir] Support memref layout maps in vector transfer ops

Matthias Springer llvmlistbot at llvm.org
Wed May 12 21:22:40 PDT 2021


Author: Matthias Springer
Date: 2021-05-13T13:22:21+09:00
New Revision: 2a51e9ff2e06d5d7096f826014916b4cc02269fc

URL: https://github.com/llvm/llvm-project/commit/2a51e9ff2e06d5d7096f826014916b4cc02269fc
DIFF: https://github.com/llvm/llvm-project/commit/2a51e9ff2e06d5d7096f826014916b4cc02269fc.diff

LOG: [mlir] Support memref layout maps in vector transfer ops

Differential Revision: https://reviews.llvm.org/D102042

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
index 7a016bd88547..7b976cc3c2a5 100644
--- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
@@ -1013,6 +1013,14 @@ struct Strategy1d<TransferWriteOp> {
   }
 };
 
+/// Return true if the innermost dim has unit stride (vacuously so if rank 0).
+static bool isLastMemrefDimUnitStride(MemRefType type) {
+  int64_t offset; // Out-param required by getStridesAndOffset; unused here.
+  SmallVector<int64_t, 4> strides;
+  // Guard `back()`: a successful query may yield no strides at all.
+  return succeeded(getStridesAndOffset(type, strides, offset)) &&
+         (strides.empty() || strides.back() == 1);
+}
+
 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
 /// necessary in cases where a 1D vector transfer op cannot be lowered into
 /// vector load/stores due to non-unit strides or broadcasts:
@@ -1052,11 +1060,14 @@ struct TransferOp1dConversion : public OpRewritePattern<OpTy> {
                                 PatternRewriter &rewriter) const override {
     ScopedContext scope(rewriter, xferOp.getLoc());
     auto map = xferOp.permutation_map();
+    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
 
+    if (!memRefType)
+      return failure();
     if (xferOp.getVectorType().getRank() != 1)
-        return failure();
-    if (map.isMinorIdentity())  // Handled by ConvertVectorToLLVM
-        return failure();
+      return failure();
+    if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
+      return failure(); // Handled by ConvertVectorToLLVM
 
     // Loop bounds, step, state...
     auto vecType = xferOp.getVectorType();


        


More information about the Mlir-commits mailing list