[Mlir-commits] [mlir] cbe7c49 - [mlir][memref] Fix runtime verification for memref.subview when size dimension value is 0 (#164897)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Mon Oct 27 11:43:48 PDT 2025
    
    
  
Author: Hanumanth
Date: 2025-10-27T11:43:45-07:00
New Revision: cbe7c49e93b630d3388dba2663b08a3c5c1bc8b6
URL: https://github.com/llvm/llvm-project/commit/cbe7c49e93b630d3388dba2663b08a3c5c1bc8b6
DIFF: https://github.com/llvm/llvm-project/commit/cbe7c49e93b630d3388dba2663b08a3c5c1bc8b6.diff
LOG: [mlir][memref] Fix runtime verification for memref.subview when size dimension value is 0 (#164897)
Previously, the runtime verification pass would insert assertion
statements with conditions that always evaluate to false for
semantically valid `memref.subview` operations where one of the
dimensions had a size of 0.
The `memref.subview` runtime verification logic was unconditionally
generating checks for the position of the last element (`offset + (size
- 1) * stride`). When `size` is 0, this expression evaluates to
`offset - stride`, which is out of bounds in the common case (e.g. -1
for offset 0 and stride 1). The assertion therefore fails at runtime
even though the operation is semantically valid.
This patch fixes the issue by making the `lastPos` check conditional.
The offset is always verified, but the endpoint check is only evaluated
when `size > 0`; otherwise the asserted condition is a constant `true`,
so no spurious assertion failure is raised.
This issue was discovered through a LiteRT model, where a dynamic shape
calculation resulted in a zero-sized dimension being passed to
`memref.subview`. The following is a simplified IR snippet from the
model. After running the runtime verification pass, an assertion that
always fails is generated because the SSA value `%5` is 0 at runtime:
it is selected from `%3`, which is loaded from the element copied out
of `@__constant_1xi32` (`dense<0>`).
```mlir
module {
  memref.global "private" constant @__constant_2xi32 : memref<2xi32> = dense<-1> {alignment = 64 : i64}
  memref.global "private" constant @__constant_1xi32 : memref<1xi32> = dense<0> {alignment = 64 : i64}
  func.func @simpleRepro(%arg0: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
    %c2 = arith.constant 2 : index
    %c4 = arith.constant 4 : index
    %c1 = arith.constant 1 : index
    %c10 = arith.constant 10 : index
    %c0 = arith.constant 0 : index
    %c-1 = arith.constant -1 : index
    %0 = memref.get_global @__constant_1xi32 : memref<1xi32>
    %1 = memref.get_global @__constant_2xi32 : memref<2xi32>
    %alloca = memref.alloca() {alignment = 64 : i64} : memref<3xi32>
    %subview = memref.subview %alloca[0] [1] [1] : memref<3xi32> to memref<1xi32, strided<[1]>>
    memref.copy %0, %subview : memref<1xi32> to memref<1xi32, strided<[1]>>
    %subview_0 = memref.subview %alloca[1] [2] [1] : memref<3xi32> to memref<2xi32, strided<[1], offset: 1>>
    memref.copy %1, %subview_0 : memref<2xi32> to memref<2xi32, strided<[1], offset: 1>>
    %2 = memref.load %alloca[%c0] : memref<3xi32>
    %3 = index.casts %2 : i32 to index
    %4 = arith.cmpi eq, %3, %c-1 : index
    %5 = arith.select %4, %c10, %3 : index
    %6 = memref.load %alloca[%c1] : memref<3xi32>
    %7 = index.casts %6 : i32 to index
    %8 = arith.cmpi eq, %7, %c-1 : index
    %9 = arith.select %8, %c4, %7 : index
    %10 = memref.load %alloca[%c2] : memref<3xi32>
    %11 = index.casts %10 : i32 to index
    %12 = arith.cmpi eq, %11, %c-1 : index
    %13 = arith.select %12, %c1, %11 : index
    %subview_1 = memref.subview %arg0[0, 0, 0] [%5, %9, %13] [1, 1, 1] : memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
    return %subview_1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
  }
}
```
P.S. This is a similar issue to the one fixed for `tensor.extract_slice`
in https://github.com/llvm/llvm-project/pull/164878
---------
Co-authored-by: Hanumanth Hanumantharayappa <hhanuman at ah-hhanuman-l.dhcp.mathworks.com>
Added: 
    
Modified: 
    mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
    mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
Removed: 
    
################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 291da1f76ca9b..14152c5a1af0c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
 
 using namespace mlir;
@@ -273,7 +274,9 @@ struct SubViewOpInterface
     Value one = arith::ConstantIndexOp::create(builder, loc, 1);
     auto metadataOp =
         ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
-    for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+    for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
+      // Reset insertion point to before the operation for each dimension
+      builder.setInsertionPoint(subView);
       Value offset = getValueOrCreateConstantIndexOp(
           builder, loc, subView.getMixedOffsets()[i]);
       Value size = getValueOrCreateConstantIndexOp(builder, loc,
@@ -290,6 +293,16 @@ struct SubViewOpInterface
                                                         std::to_string(i) +
                                                         " is out-of-bounds"));
 
+      // Only verify if size > 0
+      Value sizeIsNonZero = arith::CmpIOp::create(
+          builder, loc, arith::CmpIPredicate::sgt, size, zero);
+
+      auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
+                                    sizeIsNonZero, /*withElseRegion=*/true);
+
+      // Populate the "then" region (for size > 0).
+      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
       // Verify that slice does not run out-of-bounds.
       Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
       Value sizeMinusOneTimesStride =
@@ -298,8 +311,20 @@ struct SubViewOpInterface
           arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
       Value lastPosInBounds =
           generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
+
+      scf::YieldOp::create(builder, loc, lastPosInBounds);
+
+      // Populate the "else" region (for size == 0).
+      builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+      Value trueVal =
+          arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
+      scf::YieldOp::create(builder, loc, trueVal);
+
+      builder.setInsertionPointAfter(ifOp);
+      Value finalCondition = ifOp.getResult(0);
+
       cf::AssertOp::create(
-          builder, loc, lastPosInBounds,
+          builder, loc, finalCondition,
           generateErrorMessage(op,
                                "subview runs out-of-bounds along dimension " +
                                    std::to_string(i)));
diff  --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
index 71e813c0a6300..84875675ac3d0 100644
--- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
@@ -2,6 +2,7 @@
 // RUN:     -expand-strided-metadata \
 // RUN:     -lower-affine \
 // RUN:     -test-cf-assert \
+// RUN:     -convert-scf-to-cf \
 // RUN:     -convert-to-llvm | \
 // RUN: mlir-runner -e main -entry-point-result=void \
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
@@ -11,6 +12,7 @@
 // RUN:     -expand-strided-metadata \
 // RUN:     -lower-affine \
 // RUN:     -test-cf-assert \
+// RUN:     -convert-scf-to-cf \
 // RUN:     -convert-to-llvm="allow-pattern-rollback=0" \
 // RUN:     -reconcile-unrealized-casts | \
 // RUN: mlir-runner -e main -entry-point-result=void \
@@ -38,6 +40,17 @@ func.func @subview_dynamic_rank_reduce(%memref: memref<?x4xf32>, %offset: index,
     return
 }
 
+func.func @subview_zero_size_dim(%memref: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, 
+                                 %dim_0: index, 
+                                 %dim_1: index, 
+                                 %dim_2: index) {
+    %subview = memref.subview %memref[0, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] :
+        memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to
+        memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+    return
+}
+
+
 func.func @main() {
   %0 = arith.constant 0 : index
   %1 = arith.constant 1 : index
@@ -105,6 +118,14 @@ func.func @main() {
   // CHECK-NOT: ERROR: Runtime op verification failed
   func.call @subview_dynamic_rank_reduce(%alloca_4_dyn, %0, %1, %0) : (memref<?x4xf32>, index, index, index) -> ()
 
+  %alloca_10x4x1 = memref.alloca() : memref<10x4x1xf32>
+  %alloca_10x4x1_dyn_stride = memref.cast %alloca_10x4x1 : memref<10x4x1xf32> to memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>
+  // CHECK-NOT: ERROR: Runtime op verification failed
+  %dim_0 = arith.constant 0 : index
+  %dim_1 = arith.constant 4 : index
+  %dim_2 = arith.constant 1 : index
+  func.call @subview_zero_size_dim(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2)
+                                        : (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index) -> ()
 
   return
 }
        
    
    
More information about the Mlir-commits
mailing list