[Mlir-commits] [mlir] [mlir][memref] Fix out-of-bounds crash when reifying result dims (PR #70774)

llvmlistbot at llvm.org
Mon Oct 30 23:46:56 PDT 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Do not crash when the input IR is invalid, i.e., when the constant dimension index of a `tensor.dim`/`memref.dim` op is out of bounds for the rank of its source. This fixes #70180.

---
Full diff: https://github.com/llvm/llvm-project/pull/70774.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp (+3) 
- (added) mlir/test/Dialect/MemRef/resolve-dim-ops.mlir (+27) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index f18ae2cc9b68816..8e3b35e2acba963 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -94,6 +94,9 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
                                  reifiedResultShapes)))
       return failure();
     unsigned resultNumber = dimValue.getResultNumber();
+    // Do not apply pattern if the IR is invalid (dim out of bounds).
+    if (*dimIndex >= reifiedResultShapes[resultNumber].size())
+      return failure();
     Value replacement = getValueOrCreateConstantIndexOp(
         rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
     rewriter.replaceOp(dimOp, replacement);
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
new file mode 100644
index 000000000000000..18e9a9d02e10819
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt --resolve-ranked-shaped-type-result-dims --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @dim_out_of_bounds(
+//  CHECK-NEXT:   arith.constant
+//  CHECK-NEXT:   memref.dim
+//  CHECK-NEXT:   return
+func.func @dim_out_of_bounds(%m : memref<7x8xf32>) -> index {
+  %idx = arith.constant 7 : index
+  %0 = memref.dim %m, %idx : memref<7x8xf32>
+  return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @dim_out_of_bounds_2(
+//  CHECK-NEXT:   arith.constant
+//  CHECK-NEXT:   arith.constant
+//  CHECK-NEXT:   bufferization.alloc_tensor
+//  CHECK-NEXT:   tensor.dim
+//  CHECK-NEXT:   return
+func.func @dim_out_of_bounds_2(%idx1 : index, %idx2 : index) -> index {
+  %idx = arith.constant 7 : index
+  %sz = arith.constant 5 : index
+  %alloc = bufferization.alloc_tensor(%sz, %sz) : tensor<?x?xf32>
+  %0 = tensor.dim %alloc, %idx : tensor<?x?xf32>
+  return %0 : index
+}

``````````

</details>
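For readers outside MLIR, the guard added to `DimOfReifyRankedShapedTypeOpInterface` amounts to a bounds check before indexing into the reified shape of the op result. The following is a minimal, MLIR-free C++ sketch of that check, not the actual pattern code. `ReifiedShapes` and `resolveDim` are hypothetical names, and plain `int64_t` extents stand in for MLIR's `OpFoldResult` values.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for `reifiedResultShapes`: one list of extents per
// op result. In MLIR each extent is an OpFoldResult, not a plain integer.
using ReifiedShapes = std::vector<std::vector<int64_t>>;

// Hypothetical helper: returns the extent of dimension `dimIndex` of result
// `resultNumber`, or std::nullopt when the request cannot be resolved.
// Returning std::nullopt plays the role of `return failure()` in the patch:
// the caller leaves the dim op untouched instead of indexing out of bounds.
std::optional<int64_t> resolveDim(const ReifiedShapes &shapes,
                                  std::size_t resultNumber, int64_t dimIndex) {
  // Defensive check specific to this standalone sketch.
  if (resultNumber >= shapes.size())
    return std::nullopt;
  const std::vector<int64_t> &dims = shapes[resultNumber];
  // Counterpart of the new guard: an index outside [0, rank) means the
  // input IR is invalid, so give up rather than read past the end.
  if (dimIndex < 0 || static_cast<std::size_t>(dimIndex) >= dims.size())
    return std::nullopt;
  return dims[static_cast<std::size_t>(dimIndex)];
}
```

As a rough analogue of the second test, `bufferization.alloc_tensor(%sz, %sz)` has two reified extents (modeled here as `{{5, 5}}`). Index `1` resolves to an extent, while index `7` yields `std::nullopt`, so the `tensor.dim` is left in place, which is what that test's CHECK lines expect.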


https://github.com/llvm/llvm-project/pull/70774


More information about the Mlir-commits mailing list