[Mlir-commits] [mlir] [mlir][memref] Fix out-of-bounds crash when reifying result dims (PR #70774)
llvmlistbot at llvm.org
Mon Oct 30 23:46:56 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Do not crash when the input IR is invalid, i.e., when the dimension index operand of a `tensor.dim`/`memref.dim` op is out of bounds. This fixes #70180.
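To illustrate the failure mode: the pattern reads `reifiedResultShapes[resultNumber][*dimIndex]`, so an index past the rank of the result reads past the end of the inner vector. The C++ sketch below is a simplified model, not the MLIR code. It uses `std::vector<std::vector<int64_t>>` in place of MLIR's reified-shape type, and the helper name `lookupReifiedDim` is hypothetical. Instead of indexing out of bounds, it returns `std::nullopt`, which plays the role of the pattern's `return failure();`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Stand-in for the reified shapes: one inner vector of extents per op result.
// (In MLIR the real elements are OpFoldResult values; int64_t is a simplification.)
using ReifiedShapes = std::vector<std::vector<int64_t>>;

// Hypothetical helper: returns the reified extent, or std::nullopt when the
// requested dimension does not exist (the "invalid IR" case from the PR).
std::optional<int64_t> lookupReifiedDim(const ReifiedShapes &shapes,
                                        std::size_t resultNumber,
                                        int64_t dimIndex) {
  // Extra safeguard of this sketch (not in the patch): the result must exist.
  if (resultNumber >= shapes.size())
    return std::nullopt;
  const std::vector<int64_t> &dims = shapes[resultNumber];
  // The guard the patch adds: an index outside [0, rank) means "do not rewrite".
  if (dimIndex < 0 || static_cast<std::size_t>(dimIndex) >= dims.size())
    return std::nullopt;
  return dims[static_cast<std::size_t>(dimIndex)];
}
```

With `ReifiedShapes s = {{7, 8}};`, `lookupReifiedDim(s, 0, 1)` yields 8, whereas `lookupReifiedDim(s, 0, 7)` yields `std::nullopt`, which corresponds to the `@dim_out_of_bounds` test (a `memref<7x8xf32>` queried at index 7). The patch itself only checks `*dimIndex`.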
---
Full diff: https://github.com/llvm/llvm-project/pull/70774.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp (+3)
- (added) mlir/test/Dialect/MemRef/resolve-dim-ops.mlir (+27)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index f18ae2cc9b68816..8e3b35e2acba963 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -94,6 +94,9 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
reifiedResultShapes)))
return failure();
unsigned resultNumber = dimValue.getResultNumber();
+ // Do not apply pattern if the IR is invalid (dim out of bounds).
+ if (*dimIndex >= reifiedResultShapes[resultNumber].size())
+ return failure();
Value replacement = getValueOrCreateConstantIndexOp(
rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
rewriter.replaceOp(dimOp, replacement);
diff --git a/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
new file mode 100644
index 000000000000000..18e9a9d02e10819
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt --resolve-ranked-shaped-type-result-dims --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func @dim_out_of_bounds(
+// CHECK-NEXT: arith.constant
+// CHECK-NEXT: memref.dim
+// CHECK-NEXT: return
+func.func @dim_out_of_bounds(%m : memref<7x8xf32>) -> index {
+ %idx = arith.constant 7 : index
+ %0 = memref.dim %m, %idx : memref<7x8xf32>
+ return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @dim_out_of_bounds_2(
+// CHECK-NEXT: arith.constant
+// CHECK-NEXT: arith.constant
+// CHECK-NEXT: bufferization.alloc_tensor
+// CHECK-NEXT: tensor.dim
+// CHECK-NEXT: return
+func.func @dim_out_of_bounds_2(%idx1 : index, %idx2 : index) -> index {
+ %idx = arith.constant 7 : index
+ %sz = arith.constant 5 : index
+ %alloc = bufferization.alloc_tensor(%sz, %sz) : tensor<?x?xf32>
+ %0 = tensor.dim %alloc, %idx : tensor<?x?xf32>
+ return %0 : index
+}
``````````
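In the second test, the `tensor<?x?xf32>` shape is reified from the dynamic size operands of `bufferization.alloc_tensor`, so the reified shape has only two entries and index 7 is still out of range. The sketch below models this under stated assumptions: operand values are plain integers rather than SSA values, `kDynamicSketch` is a made-up sentinel for `?`, and `reifyAllocShape`/`resolveDim` are hypothetical names. As a simplification, it assumes each dynamic extent is taken from the next dynamic-size operand in order.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical sentinel for a dynamic extent ("?"), used only in this sketch;
// MLIR has its own constant for dynamic dimensions.
constexpr int64_t kDynamicSketch = -1;

// Assumed model of reifying `alloc_tensor(%sz, %sz) : tensor<?x?xf32>`:
// static extents are kept as-is, and each dynamic extent consumes the next
// dynamic-size operand, in order.
std::vector<int64_t> reifyAllocShape(const std::vector<int64_t> &type,
                                     const std::vector<int64_t> &dynamicSizes) {
  std::vector<int64_t> extents;
  std::size_t next = 0;
  for (int64_t d : type)
    extents.push_back(d == kDynamicSketch ? dynamicSizes.at(next++) : d);
  return extents;
}

// The bounds check from the patch, applied to the reified extents.
std::optional<int64_t> resolveDim(const std::vector<int64_t> &extents,
                                  int64_t index) {
  if (index < 0 || static_cast<std::size_t>(index) >= extents.size())
    return std::nullopt; // models failure(): the tensor.dim op stays in the IR
  return extents[static_cast<std::size_t>(index)];
}
```

For a `?x?` type with sizes `{5, 5}`, `resolveDim` returns 5 for index 1 and `std::nullopt` for index 7, so the `tensor.dim` is left in place, consistent with the `CHECK-NEXT: tensor.dim` line.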
</details>
https://github.com/llvm/llvm-project/pull/70774