[Mlir-commits] [mlir] 919922b - [mlir] Added verification check for linalg.conv to ensure memrefs are of rank > 2

Alex Zinenko llvmlistbot at llvm.org
Thu Jul 23 03:27:12 PDT 2020


Author: Jakub Lichman
Date: 2020-07-23T12:27:05+02:00
New Revision: 919922b0c20e99de51bb0d06c98a5fc2fa5feec4

URL: https://github.com/llvm/llvm-project/commit/919922b0c20e99de51bb0d06c98a5fc2fa5feec4
DIFF: https://github.com/llvm/llvm-project/commit/919922b0c20e99de51bb0d06c98a5fc2fa5feec4.diff

LOG: [mlir] Added verification check for linalg.conv to ensure memrefs are of rank > 2

linalg.conv does not support memrefs of rank smaller than 3, as stated here:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/nn/convolution

However, the verifier did not check this, so the op crashed with "LLVM ERROR: out of memory"
in the 1D case and with the assertion "nWin > 0 && "expected at least one window dimension""
in the 2D case. This commit adds the missing check to the verification method.
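The standalone C++ sketch below is an illustration only, not the Linalg implementation. It assumes that the op reserves one batch dimension and one feature dimension, treats the remaining dimensions as window dimensions, and computes their count with unsigned arithmetic. This is an assumption and is not stated in the patch. Under that assumption, rank 1 wraps around to a huge window count, which would explain the out-of-memory error, and rank 2 yields zero window dimensions, which trips the assertion. `numWindowDims` and `verifyConvRanks` are hypothetical names. `verifyConvRanks` mirrors the order of the checks in the patched `verify(ConvOp)`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// ASSUMPTION (illustrative only): one batch and one feature dimension are
// reserved; every remaining dimension is a window (spatial) dimension.
constexpr unsigned kNumBatchDims = 1;
constexpr unsigned kNumFeatureDims = 1;

// Hypothetical helper, not the actual Linalg code. Unsigned subtraction wraps:
// rank 1 yields UINT_MAX (the assert passes, but anything sized by the result
// would be enormous), while rank 2 yields 0 and trips the assert.
unsigned numWindowDims(unsigned rank) {
  unsigned nWin = rank - kNumBatchDims - kNumFeatureDims;
  assert(nWin > 0 && "expected at least one window dimension");
  return nWin;
}

// Mirrors the order of checks in the patched verifier: the three ranks are
// compared first, so checking only the output rank afterwards covers all of
// them. Returns the diagnostic text, or std::nullopt if the ranks are valid.
std::optional<std::string> verifyConvRanks(int64_t oRank, int64_t iRank,
                                           int64_t fRank) {
  if (oRank != iRank || oRank != fRank)
    return std::string("expects memref ranks to match");
  if (oRank <= 2)
    return std::string("expects memref ranks to be greater than 2");
  return std::nullopt;
}
```

For example, `verifyConvRanks(1, 1, 1)` returns the new diagnostic. This corresponds to the `conv_rank_limit` test case, whose three operands are all `memref<?xf32>`. Because the new check runs only after the ranks are known to match, it is enough to check `oType` alone.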

Differential Revision: https://reviews.llvm.org/D84317

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/test/Dialect/Linalg/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 32e0b12c22b5..4c68f0265677 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -995,6 +995,8 @@ static LogicalResult verify(ConvOp op) {
     return op.emitOpError("expects memref elemental types to match");
   if (oType.getRank() != iType.getRank() || oType.getRank() != fType.getRank())
     return op.emitOpError("expects memref ranks to match");
+  if (oType.getRank() <= 2)
+    return op.emitOpError("expects memref ranks to be greater than 2");
   if (auto strides = op.strides()) {
     if (failed(
             verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true)))

diff  --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 99b942389e41..79d61d8d7892 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -428,6 +428,13 @@ func @generic_result_0_element_type(%arg0: memref<?xf32>) {
 
 // -----
 
+func @conv_rank_limit(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
+  // expected-error @+1 {{expects memref ranks to be greater than 2}}
+  linalg.conv(%arg0, %arg1, %arg2) : memref<?xf32>, memref<?xf32>, memref<?xf32>
+}
+
+// -----
+
 // expected-error @+1 {{unknown Linalg type}}
 !invalid_type = type !linalg.unknown
 