[Mlir-commits] [mlir] [mlir][memref] Fix runtime verification for memref.subview when size dimension value is 0 (PR #164897)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 27 07:47:19 PDT 2025


https://github.com/Hanumanth04 updated https://github.com/llvm/llvm-project/pull/164897

>From 7a91711c9ee70b363f5ec40e6a587ce9026f1377 Mon Sep 17 00:00:00 2001
From: Hanumanth Hanumantharayappa <hhanuman at ah-hhanuman-l.dhcp.mathworks.com>
Date: Thu, 23 Oct 2025 17:14:32 -0400
Subject: [PATCH 1/3] [mlir][memref] Fix runtime verification for
 memref.subview when size is 0

---
 .../Transforms/RuntimeOpVerification.cpp      | 42 ++++++++++++++++++-
 .../MemRef/subview-runtime-verification.mlir  | 19 +++++++++
 2 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 291da1f76ca9b..1979d5b7e6310 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
@@ -273,7 +274,9 @@ struct SubViewOpInterface
     Value one = arith::ConstantIndexOp::create(builder, loc, 1);
     auto metadataOp =
         ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
-    for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+    for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
+      // Reset insertion point to before the operation for each dimension
+      builder.setInsertionPoint(subView);
       Value offset = getValueOrCreateConstantIndexOp(
           builder, loc, subView.getMixedOffsets()[i]);
       Value size = getValueOrCreateConstantIndexOp(builder, loc,
@@ -290,6 +293,42 @@ struct SubViewOpInterface
                                                         std::to_string(i) +
                                                         " is out-of-bounds"));
 
+      // Only verify if size > 0
+      Value sizeIsNonZero = arith::CmpIOp::create(
+          builder, loc, arith::CmpIPredicate::sgt, size, zero);
+
+      /*
+       * Split the current block to create the below control flow structure:
+       *
+       * ^preCondBlock:
+       *   ... // offset check already done above
+       *   %size_nonzero = arith.cmpi sgt, %size, %zero
+       *   cf.cond_br %size_nonzero, ^sizeBoundsCheckBlock, ^afterCheckBlock
+       *
+       * ^sizeBoundsCheckBlock:
+       *   %last_pos = ... // compute offset + (size-1) * stride
+       *   %last_pos_ok = ... // last position bounds check
+       *   cf.assert %last_pos_ok, "extract_slice runs out-of-bounds"
+       *   cf.br ^afterCheckBlock
+       *
+       * ^afterCheckBlock:
+       *   tensor.extract_slice ... // the original operation
+       */
+      Block *preCondBlock = builder.getBlock();
+      Block *afterCheckBlock = preCondBlock->splitBlock(subView);
+
+      // Create the block for conditional size bounds verification.
+      Block *sizeBoundsCheckBlock = builder.createBlock(
+          preCondBlock->getParent(), Region::iterator(afterCheckBlock));
+
+      // Terminate the pre-condition block with the conditional branch.
+      builder.setInsertionPointToEnd(preCondBlock);
+      cf::CondBranchOp::create(builder, loc, sizeIsNonZero,
+                               sizeBoundsCheckBlock, afterCheckBlock);
+
+      // Populate the size bounds check block with lastPos verification.
+      builder.setInsertionPointToStart(sizeBoundsCheckBlock);
+
       // Verify that slice does not run out-of-bounds.
       Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
       Value sizeMinusOneTimesStride =
@@ -303,6 +342,7 @@ struct SubViewOpInterface
           generateErrorMessage(op,
                                "subview runs out-of-bounds along dimension " +
                                    std::to_string(i)));
+      cf::BranchOp::create(builder, loc, afterCheckBlock);
     }
   }
 };
diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
index 71e813c0a6300..001c435086976 100644
--- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
@@ -38,6 +38,17 @@ func.func @subview_dynamic_rank_reduce(%memref: memref<?x4xf32>, %offset: index,
     return
 }
 
+func.func @subview_zero_size_dim(%memref: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, 
+                                 %dim_0: index, 
+                                 %dim_1: index, 
+                                 %dim_2: index) {
+    %subview = memref.subview %memref[0, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] :
+        memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to
+        memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+    return
+}
+
+
 func.func @main() {
   %0 = arith.constant 0 : index
   %1 = arith.constant 1 : index
@@ -105,6 +116,14 @@ func.func @main() {
   // CHECK-NOT: ERROR: Runtime op verification failed
   func.call @subview_dynamic_rank_reduce(%alloca_4_dyn, %0, %1, %0) : (memref<?x4xf32>, index, index, index) -> ()
 
+  %alloca_10x4x1 = memref.alloca() : memref<10x4x1xf32>
+  %alloca_10x4x1_dyn_stride = memref.cast %alloca_10x4x1 : memref<10x4x1xf32> to memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>
+  // CHECK-NOT: ERROR: Runtime op verification failed
+  %dim_0 = arith.constant 0 : index
+  %dim_1 = arith.constant 4 : index
+  %dim_2 = arith.constant 1 : index
+  func.call @subview_zero_size_dim(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2)
+                                        : (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index) -> ()
 
   return
 }

>From 4b291a71b2a300c6ec3b3b0abc6441f3f1cd757c Mon Sep 17 00:00:00 2001
From: Hanumanth Hanumantharayappa <hhanuman at ah-hhanuman-l.dhcp.mathworks.com>
Date: Thu, 23 Oct 2025 17:32:48 -0400
Subject: [PATCH 2/3] Remove redundant header include

---
 mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 1979d5b7e6310..6a8bc1e9e2ac4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
-#include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"

>From ed2b9ab78d097b93b601b8539991d0b7fd3193a2 Mon Sep 17 00:00:00 2001
From: Hanumanth Hanumantharayappa <hhanuman at ah-hhanuman-l.dhcp.mathworks.com>
Date: Mon, 27 Oct 2025 10:46:38 -0400
Subject: [PATCH 3/3] Use SCF dialect ops for condition checking

---
 .../Transforms/RuntimeOpVerification.cpp      | 52 +++++++------------
 .../MemRef/subview-runtime-verification.mlir  |  2 +
 2 files changed, 21 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 6a8bc1e9e2ac4..14152c5a1af0c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
 
 using namespace mlir;
@@ -296,37 +297,11 @@ struct SubViewOpInterface
       Value sizeIsNonZero = arith::CmpIOp::create(
           builder, loc, arith::CmpIPredicate::sgt, size, zero);
 
-      /*
-       * Split the current block to create the below control flow structure:
-       *
-       * ^preCondBlock:
-       *   ... // offset check already done above
-       *   %size_nonzero = arith.cmpi sgt, %size, %zero
-       *   cf.cond_br %size_nonzero, ^sizeBoundsCheckBlock, ^afterCheckBlock
-       *
-       * ^sizeBoundsCheckBlock:
-       *   %last_pos = ... // compute offset + (size-1) * stride
-       *   %last_pos_ok = ... // last position bounds check
-       *   cf.assert %last_pos_ok, "extract_slice runs out-of-bounds"
-       *   cf.br ^afterCheckBlock
-       *
-       * ^afterCheckBlock:
-       *   tensor.extract_slice ... // the original operation
-       */
-      Block *preCondBlock = builder.getBlock();
-      Block *afterCheckBlock = preCondBlock->splitBlock(subView);
-
-      // Create the block for conditional size bounds verification.
-      Block *sizeBoundsCheckBlock = builder.createBlock(
-          preCondBlock->getParent(), Region::iterator(afterCheckBlock));
-
-      // Terminate the pre-condition block with the conditional branch.
-      builder.setInsertionPointToEnd(preCondBlock);
-      cf::CondBranchOp::create(builder, loc, sizeIsNonZero,
-                               sizeBoundsCheckBlock, afterCheckBlock);
-
-      // Populate the size bounds check block with lastPos verification.
-      builder.setInsertionPointToStart(sizeBoundsCheckBlock);
+      auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
+                                    sizeIsNonZero, /*withElseRegion=*/true);
+
+      // Populate the "then" region (for size > 0).
+      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
 
       // Verify that slice does not run out-of-bounds.
       Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
@@ -336,12 +311,23 @@ struct SubViewOpInterface
           arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
       Value lastPosInBounds =
           generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
+
+      scf::YieldOp::create(builder, loc, lastPosInBounds);
+
+      // Populate the "else" region (for size == 0).
+      builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+      Value trueVal =
+          arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
+      scf::YieldOp::create(builder, loc, trueVal);
+
+      builder.setInsertionPointAfter(ifOp);
+      Value finalCondition = ifOp.getResult(0);
+
       cf::AssertOp::create(
-          builder, loc, lastPosInBounds,
+          builder, loc, finalCondition,
           generateErrorMessage(op,
                                "subview runs out-of-bounds along dimension " +
                                    std::to_string(i)));
-      cf::BranchOp::create(builder, loc, afterCheckBlock);
     }
   }
 };
diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
index 001c435086976..84875675ac3d0 100644
--- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
@@ -2,6 +2,7 @@
 // RUN:     -expand-strided-metadata \
 // RUN:     -lower-affine \
 // RUN:     -test-cf-assert \
+// RUN:     -convert-scf-to-cf \
 // RUN:     -convert-to-llvm | \
 // RUN: mlir-runner -e main -entry-point-result=void \
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
@@ -11,6 +12,7 @@
 // RUN:     -expand-strided-metadata \
 // RUN:     -lower-affine \
 // RUN:     -test-cf-assert \
+// RUN:     -convert-scf-to-cf \
 // RUN:     -convert-to-llvm="allow-pattern-rollback=0" \
 // RUN:     -reconcile-unrealized-casts | \
 // RUN: mlir-runner -e main -entry-point-result=void \



More information about the Mlir-commits mailing list