[Mlir-commits] [mlir] 2bf491c - [mlir][VectorOps] Fail fast when a strided memref is passed to vector_transfer

Benjamin Kramer llvmlistbot at llvm.org
Wed Sep 2 01:35:42 PDT 2020


Author: Benjamin Kramer
Date: 2020-09-02T10:34:36+02:00
New Revision: 2bf491c7294c020d1754cddbf3a55e8e21c14bdc

URL: https://github.com/llvm/llvm-project/commit/2bf491c7294c020d1754cddbf3a55e8e21c14bdc
DIFF: https://github.com/llvm/llvm-project/commit/2bf491c7294c020d1754cddbf3a55e8e21c14bdc.diff

LOG: [mlir][VectorOps] Fail fast when a strided memref is passed to vector_transfer

Otherwise we'll silently miscompile things.

Differential Revision: https://reviews.llvm.org/D86951

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 23af60be585c..ecb047a1ad14 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1025,6 +1025,25 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
   bool hasBoundedRewriteRecursion() const final { return true; }
 };
 
+/// Returns true if `memRefType` has a contiguous layout, i.e. the innermost
+/// stride is 1 and each outer stride equals the next inner stride times the
+/// next inner size. Writes the strides to `strides`; false if that fails.
+static bool isContiguous(MemRefType memRefType,
+                         SmallVectorImpl<int64_t> &strides) {
+  int64_t offset;
+  if (failed(getStridesAndOffset(memRefType, strides, offset)))
+    return false;
+  // Guard `back()`: with no strides (rank 0) there is nothing to violate.
+  if (!strides.empty() && strides.back() != 1)
+    return false;
+  auto sizes = memRefType.getShape();
+  // Compare every adjacent pair of dimensions, innermost pair included.
+  for (int index = 0, e = strides.size(); index + 1 < e; ++index)
+    if (strides[index] != strides[index + 1] * sizes[index + 1])
+      return false;
+  return true;
+}
+
 class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorTypeCastOpConversion(MLIRContext *context,
@@ -1058,22 +1077,9 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
       return failure();
 
-    int64_t offset;
-    SmallVector<int64_t, 4> strides;
-    auto successStrides =
-        getStridesAndOffset(sourceMemRefType, strides, offset);
-    bool isContiguous = (strides.back() == 1);
-    if (isContiguous) {
-      auto sizes = sourceMemRefType.getShape();
-      for (int index = 0, e = strides.size() - 2; index < e; ++index) {
-        if (strides[index] != strides[index + 1] * sizes[index + 1]) {
-          isContiguous = false;
-          break;
-        }
-      }
-    }
     // Only contiguous source tensors supported atm.
-    if (failed(successStrides) || !isContiguous)
+    SmallVector<int64_t, 4> strides;
+    if (!isContiguous(sourceMemRefType, strides))
       return failure();
 
     auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
@@ -1141,6 +1147,10 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
                                        xferOp.getVectorType().getRank(),
                                        op->getContext()))
       return failure();
+    // Only contiguous source memrefs are supported at the moment.
+    SmallVector<int64_t, 4> strides;
+    if (!isContiguous(xferOp.getMemRefType(), strides))
+      return failure();
 
     auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
 


        


More information about the Mlir-commits mailing list