[Mlir-commits] [mlir] [mlir][memref] Add terminator check to prevent a crash (PR #141972)

Longsheng Mou llvmlistbot at llvm.org
Thu May 29 09:13:36 PDT 2025


https://github.com/CoTinker created https://github.com/llvm/llvm-project/pull/141972

This PR adds a terminator check to prevent a crash when invoking `lastNonTerminatorInRegion`. Fixes #137333.

>From a3e6848e0f35de5b03e25c0f2911002cb638e721 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Fri, 30 May 2025 00:09:26 +0800
Subject: [PATCH] [mlir][memref] Add terminator check to prevent a crash

This PR adds a terminator check to prevent a crash when invoking
`lastNonTerminatorInRegion`.
---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   |  7 +++--
 mlir/test/Dialect/MemRef/canonicalize.mlir | 35 ++++++++++++++++++++++
 2 files changed, 39 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index cab0ab8d15d5d..aa9587510670c 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -398,8 +398,9 @@ static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
 /// and is only followed by a terminator. This prevents
 /// extending the lifetime of allocations.
 static bool lastNonTerminatorInRegion(Operation *op) {
-  return op->getNextNode() == op->getBlock()->getTerminator() &&
-         llvm::hasSingleElement(op->getParentRegion()->getBlocks());
+  return op->getBlock()->mightHaveTerminator() &&
+         op->getNextNode() == op->getBlock()->getTerminator() &&
+         op->getParentRegion()->hasOneBlock();
 }
 
 /// Inline an AllocaScopeOp if either the direct parent is an allocation scope
@@ -2011,7 +2012,7 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
       // Second, check the sizes.
       if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
                        op.getConstifiedMixedSizes()))
-          return false;
+        return false;
 
       // Finally, check the offset.
       assert(op.getMixedOffsets().size() == 1 &&
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e7cee7cd85426..6f17caad3fd6a 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -739,6 +739,8 @@ func.func @scopeMerge() {
 // CHECK:     "test.use"(%[[alloc]]) : (memref<?xi64>) -> ()
 // CHECK:     return
 
+// -----
+
 func.func @scopeMerge2() {
   "test.region"() ({
     memref.alloca_scope {
@@ -763,6 +765,8 @@ func.func @scopeMerge2() {
 // CHECK:     return
 // CHECK:   }
 
+// -----
+
 func.func @scopeMerge3() {
   %cnt = "test.count"() : () -> index
   "test.region"() ({
@@ -787,6 +791,8 @@ func.func @scopeMerge3() {
 // CHECK:     return
 // CHECK:   }
 
+// -----
+
 func.func @scopeMerge4() {
   %cnt = "test.count"() : () -> index
   "test.region"() ({
@@ -813,6 +819,8 @@ func.func @scopeMerge4() {
 // CHECK:     return
 // CHECK:   }
 
+// -----
+
 func.func @scopeMerge5() {
   "test.region"() ({
     memref.alloca_scope {
@@ -839,6 +847,8 @@ func.func @scopeMerge5() {
 // CHECK:     return
 // CHECK:   }
 
+// -----
+
 func.func @scopeInline(%arg : memref<index>) {
   %cnt = "test.count"() : () -> index
   "test.region"() ({
@@ -855,6 +865,31 @@ func.func @scopeInline(%arg : memref<index>) {
 
 // -----
 
+// Ensure this case does not crash.
+
+// CHECK-LABEL:   func.func @scope_merge_without_terminator() {
+// CHECK:           "test.region"() ({
+// CHECK:             memref.alloca_scope  {
+// CHECK:               %[[cnt:.*]] = "test.count"() : () -> index
+// CHECK:               %[[alloc:.*]] = memref.alloca(%[[cnt]]) : memref<?xi64>
+// CHECK:               "test.use"(%[[alloc]]) : (memref<?xi64>) -> ()
+// CHECK:             }
+// CHECK:           }) : () -> ()
+// CHECK:           return
+// CHECK:         }
+func.func @scope_merge_without_terminator() {
+  "test.region"() ({
+    memref.alloca_scope {
+      %cnt = "test.count"() : () -> index
+      %a = memref.alloca(%cnt) : memref<?xi64>
+      "test.use"(%a) : (memref<?xi64>) -> ()
+    }
+  }) : () -> ()
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @reinterpret_noop
 //  CHECK-SAME: (%[[ARG:.*]]: memref<2x3x4xf32>)
 //  CHECK-NEXT: return %[[ARG]]



More information about the Mlir-commits mailing list