[Mlir-commits] [mlir] 2bf491c - [mlir][VectorOps] Fail fast when a strided memref is passed to vector_transfer
Benjamin Kramer
llvmlistbot at llvm.org
Wed Sep 2 01:35:42 PDT 2020
Author: Benjamin Kramer
Date: 2020-09-02T10:34:36+02:00
New Revision: 2bf491c7294c020d1754cddbf3a55e8e21c14bdc
URL: https://github.com/llvm/llvm-project/commit/2bf491c7294c020d1754cddbf3a55e8e21c14bdc
DIFF: https://github.com/llvm/llvm-project/commit/2bf491c7294c020d1754cddbf3a55e8e21c14bdc.diff
LOG: [mlir][VectorOps] Fail fast when a strided memref is passed to vector_transfer
Otherwise we'll silently miscompile things.
Differential Revision: https://reviews.llvm.org/D86951
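For context: a memref is "contiguous" here when its innermost stride is 1 and every outer stride equals the next stride times the next dimension's size, i.e. the buffer is laid out densely in row-major order; anything else (e.g. a view that skips rows) is "strided" and was previously lowered as if it were dense. Below is a minimal standalone C++ sketch of that condition, mirroring the loop in the isContiguous helper added in this patch but using plain std::vector in place of MLIR's MemRefType and getStridesAndOffset; the function name, shapes, and strides are made up for illustration only.

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative stand-in for the contiguity test: the innermost stride must be
// 1, and each outer stride must equal the next stride times the next dim size.
static bool isContiguousLayout(const std::vector<int64_t> &sizes,
                               const std::vector<int64_t> &strides) {
  if (strides.empty() || strides.back() != 1)
    return false;
  // Same loop bounds as the patch: compare each outer pair of strides.
  for (int index = 0, e = static_cast<int>(strides.size()) - 2; index < e;
       ++index)
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return false;
  return true;
}

int main() {
  // A dense 2x4x8 row-major memref has strides [32, 8, 1]: contiguous.
  assert(isContiguousLayout({2, 4, 8}, {32, 8, 1}));
  // A view that skips every other 4x8 slab has strides [64, 8, 1]:
  // 64 != 8 * 4, so the conversion now returns failure() instead of
  // silently emitting loads/stores that assume a dense layout.
  assert(!isContiguousLayout({2, 4, 8}, {64, 8, 1}));
  return 0;
}

In the diff below, the check that already existed inline in VectorTypeCastOpConversion is factored into an isContiguous helper, and VectorTransferConversion gains the same guard so it returns failure() on strided memrefs instead of miscompiling.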
Added:
Modified:
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 23af60be585c..ecb047a1ad14 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1025,6 +1025,25 @@ class VectorInsertStridedSliceOpSameRankRewritePattern
   bool hasBoundedRewriteRecursion() const final { return true; }
 };
 
+/// Returns true if the memory underlying `memRefType` has a contiguous layout.
+/// Strides are written to `strides`.
+static bool isContiguous(MemRefType memRefType,
+                         SmallVectorImpl<int64_t> &strides) {
+  int64_t offset;
+  auto successStrides = getStridesAndOffset(memRefType, strides, offset);
+  bool isContiguous = (strides.back() == 1);
+  if (isContiguous) {
+    auto sizes = memRefType.getShape();
+    for (int index = 0, e = strides.size() - 2; index < e; ++index) {
+      if (strides[index] != strides[index + 1] * sizes[index + 1]) {
+        isContiguous = false;
+        break;
+      }
+    }
+  }
+  return succeeded(successStrides) && isContiguous;
+}
+
 class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorTypeCastOpConversion(MLIRContext *context,
@@ -1058,22 +1077,9 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
       return failure();
 
-    int64_t offset;
-    SmallVector<int64_t, 4> strides;
-    auto successStrides =
-        getStridesAndOffset(sourceMemRefType, strides, offset);
-    bool isContiguous = (strides.back() == 1);
-    if (isContiguous) {
-      auto sizes = sourceMemRefType.getShape();
-      for (int index = 0, e = strides.size() - 2; index < e; ++index) {
-        if (strides[index] != strides[index + 1] * sizes[index + 1]) {
-          isContiguous = false;
-          break;
-        }
-      }
-    }
     // Only contiguous source tensors supported atm.
-    if (failed(successStrides) || !isContiguous)
+    SmallVector<int64_t, 4> strides;
+    if (!isContiguous(sourceMemRefType, strides))
       return failure();
 
     auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
@@ -1141,6 +1147,10 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
                                        xferOp.getVectorType().getRank(),
                                        op->getContext()))
       return failure();
+    // Only contiguous source tensors supported atm.
+    SmallVector<int64_t, 4> strides;
+    if (!isContiguous(xferOp.getMemRefType(), strides))
+      return failure();
 
     auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };