[Mlir-commits] [mlir] 919922b - [mlir] Added verification check for linalg.conv to ensure memrefs are of rank > 2
Alex Zinenko
llvmlistbot at llvm.org
Thu Jul 23 03:27:12 PDT 2020
Author: Jakub Lichman
Date: 2020-07-23T12:27:05+02:00
New Revision: 919922b0c20e99de51bb0d06c98a5fc2fa5feec4
URL: https://github.com/llvm/llvm-project/commit/919922b0c20e99de51bb0d06c98a5fc2fa5feec4
DIFF: https://github.com/llvm/llvm-project/commit/919922b0c20e99de51bb0d06c98a5fc2fa5feec4.diff
LOG: [mlir] Added verification check for linalg.conv to ensure memrefs are of rank > 2
linalg.conv does not support memrefs with rank smaller than 3, as stated here:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/nn/convolution
However, this is not verified, so processing such an op crashes with an "LLVM ERROR: out of memory"
error in the 1D case and with the "nWin > 0 && "expected at least one window dimension""
assertion in the 2D case. This commit adds that check to the verification method.
Differential Revision: https://reviews.llvm.org/D84317
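The two crash messages are consistent with the number of window (spatial) dimensions being computed as rank - 2 in unsigned arithmetic. The following standalone C++ sketch assumes exactly that; `numWindowDims` and `windowGuardPasses` are hypothetical names, not Linalg functions, and the assumption that every operand carries two non-window dimensions follows the TensorFlow layout linked above.

```cpp
#include <cassert>
#include <limits>

// Hypothetical helper (not the Linalg API): each conv operand is assumed to
// carry exactly two non-window dimensions, so rank - 2 dimensions are
// windowed. The subtraction is done in unsigned (modular) arithmetic.
unsigned numWindowDims(unsigned rank) {
  return rank - 2;  // rank 2 -> 0; rank 1 wraps to the largest unsigned value
}

// A guard in the style of `nWin > 0`: it rejects rank 2 but, because of the
// wrap-around, accepts rank 1.
bool windowGuardPasses(unsigned rank) {
  return numWindowDims(rank) > 0;
}
```

Under this assumption, rank 2 yields zero window dimensions and fails a `nWin > 0` style guard, while rank 1 wraps to the largest unsigned value, passes the guard, and would then request an enormous allocation, which would explain the out-of-memory error. Rejecting `rank <= 2` in the verifier stops both cases before that arithmetic runs.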
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 32e0b12c22b5..4c68f0265677 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -995,6 +995,8 @@ static LogicalResult verify(ConvOp op) {
     return op.emitOpError("expects memref elemental types to match");
   if (oType.getRank() != iType.getRank() || oType.getRank() != fType.getRank())
     return op.emitOpError("expects memref ranks to match");
+  if (oType.getRank() <= 2)
+    return op.emitOpError("expects memref ranks to be greater than 2");
   if (auto strides = op.strides()) {
     if (failed(
             verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true)))
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 99b942389e41..79d61d8d7892 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -428,6 +428,13 @@ func @generic_result_0_element_type(%arg0: memref<?xf32>) {
// -----
+func @conv_rank_limit(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
+  // expected-error @+1 {{expects memref ranks to be greater than 2}}
+  linalg.conv(%arg0, %arg1, %arg2) : memref<?xf32>, memref<?xf32>, memref<?xf32>
+}
+
+// -----
+
// expected-error @+1 {{unknown Linalg type}}
!invalid_type = type !linalg.unknown
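To show how the new check interacts with the existing rank-equality check in `verify(ConvOp)`, here is a standalone C++ sketch that takes plain integer ranks instead of `MemRefType`s and returns the diagnostic text instead of calling `emitOpError`. `checkConvRanks` is a hypothetical name; only the two messages are copied from the diff.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for the rank checks in verify(ConvOp); the real code
// reads the ranks from the memref types of the output, input and filter.
// Returns an empty string when the ranks are acceptable, otherwise the
// diagnostic text the op would emit.
std::string checkConvRanks(unsigned outputRank, unsigned inputRank,
                           unsigned filterRank) {
  if (outputRank != inputRank || outputRank != filterRank)
    return "expects memref ranks to match";
  // The ranks are known to agree here, so checking the output rank alone
  // covers all three operands.
  if (outputRank <= 2)
    return "expects memref ranks to be greater than 2";
  return "";
}
```

Because the equality check runs first, `checkConvRanks(1, 1, 1)`, which mirrors the three `memref<?xf32>` operands in the new test, reports the greater-than-2 diagnostic, while operands with mismatched ranks still report the earlier "expects memref ranks to match" error.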