[Mlir-commits] [mlir] [mlir][tensor] Fix runtime verification for `tensor.extract_slice` when size is 0 (PR #164878)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 23 12:34:53 PDT 2025


https://github.com/Hanumanth04 created https://github.com/llvm/llvm-project/pull/164878

Previously, the runtime verification pass would insert assertion statements with conditions that always evaluate to false for semantically valid `tensor.extract_slice` operations where one of the dimensions had a size of 0.

The `tensor.extract_slice` runtime verification logic was unconditionally generating checks for the position of the last element (`offset + (size - 1) * stride`). When `size` is 0, this expression reduces to `offset - stride`, which is negative whenever `offset < stride` (for example, offset 0 with stride 1), so the `0 <= lastPos` assertion fails at runtime even though the operation is semantically valid.

This patch fixes the issue by making the `lastPos` check conditional. The offset is always verified, but the endpoint check is only performed when `size > 0` to avoid generating spurious assert statements.

This issue was discovered through a LiteRT model, where a dynamic shape calculation resulted in a zero-sized dimension being passed to `tensor.extract_slice`.

The following is a simplified IR snippet from the model. After running the runtime verification pass, an assertion that always fails is generated because the SSA value `%3` becomes 0.

```mlir
// Simplified reproducer: the three slice sizes are read from a 3-element i32
// tensor holding [0, -1, -1]; a value of -1 is replaced by the extent of the
// matching source dimension (10, 4, 1). The sizes passed to
// tensor.extract_slice are therefore [0, 4, 1], i.e. dimension 0 is zero-sized.
func.func @simple_repro_from_liteRT_model(%arg0: tensor<10x4x1xf32>) -> tensor<?x?x?xf32> {
  %cst = arith.constant dense<0> : tensor<1xi32>
  %cst_0 = arith.constant dense<-1> : tensor<2xi32>
  %c-1 = arith.constant -1 : index
  %c0 = arith.constant 0 : index
  %c10 = arith.constant 10 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  %c2 = arith.constant 2 : index
  // Assemble the size tensor: element 0 comes from %cst (0), elements 1..2
  // come from %cst_0 (-1, -1).
  %0 = tensor.empty() : tensor<3xi32>
  %inserted_slice = tensor.insert_slice %cst into %0[0] [1] [1] : tensor<1xi32> into tensor<3xi32>
  %inserted_slice_1 = tensor.insert_slice %cst_0 into %inserted_slice[1] [2] [1] : tensor<2xi32> into tensor<3xi32>
  // Size for dimension 0: element 0 is 0, which is not -1, so the select keeps
  // it and %3 == 0.
  %extracted = tensor.extract %inserted_slice_1[%c0] : tensor<3xi32>
  %1 = index.casts %extracted : i32 to index
  %2 = arith.cmpi eq, %1, %c-1 : index
  %3 = arith.select %2, %c10, %1 : index
  // Size for dimension 1: element 1 is -1, so the select picks %c4 and %6 == 4.
  %extracted_2 = tensor.extract %inserted_slice_1[%c1] : tensor<3xi32>
  %4 = index.casts %extracted_2 : i32 to index
  %5 = arith.cmpi eq, %4, %c-1 : index
  %6 = arith.select %5, %c4, %4 : index
  // Size for dimension 2: element 2 is -1, so the select picks %c1 and %9 == 1.
  %extracted_3 = tensor.extract %inserted_slice_1[%c2] : tensor<3xi32>
  %7 = index.casts %extracted_3 : i32 to index
  %8 = arith.cmpi eq, %7, %c-1 : index
  %9 = arith.select %8, %c1, %7 : index
  // Offsets are all 0 and strides are all 1; with %3 == 0 the last-element
  // position along dimension 0 is 0 + (0 - 1) * 1 = -1.
  %extracted_slice = tensor.extract_slice %arg0[0, 0, 0] [%3, %6, %9] [1, 1, 1] : tensor<10x4x1xf32> to tensor<?x?x?xf32>
  return %extracted_slice : tensor<?x?x?xf32>
}
```

The issue can be reproduced more simply with the following test case, where `dim_0` is `0`. When the runtime verification pass is applied to this code with `dim_0 = 0`, it generates an assertion that will always fail at runtime.

```mlir
// Extracts a slice of %arg0 at offsets [0, 0, 0] with unit strides and the
// runtime sizes %dim_0 x %dim_1 x %dim_2. The slice result is not used; the
// function exists only so the runtime verification pass has an
// extract_slice with dynamic sizes to instrument.
func.func @extract_slice_zero_size_dim(%arg0: tensor<10x4x1xf32>,
                                      %dim_0: index,
                                      %dim_1: index,
                                      %dim_2: index) {
  %slice = tensor.extract_slice %arg0[0, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1]
    : tensor<10x4x1xf32> to tensor<?x?x?xf32>
  return
}

// Driver for the reproducer: calls @extract_slice_zero_size_dim on a constant
// 10x4x1 tensor with slice sizes 0, 4 and 1.
func.func @test_zero_size_extraction() {
  %input = arith.constant dense<1.0> : tensor<10x4x1xf32>
  // Define slice dimensions: 0x4x1 (zero-size in first dimension)
  %dim_0 = arith.constant 0 : index
  %dim_1 = arith.constant 4 : index
  %dim_2 = arith.constant 1 : index
  func.call @extract_slice_zero_size_dim(%input, %dim_0, %dim_1, %dim_2)
    : (tensor<10x4x1xf32>, index, index, index) -> ()
  return
}
```

P.S. We probably have a similar issue with `memref.subview`. I will check this and send a separate PR for the issue.

>From 5715e8f0f2f77ce798b61353a82901ab66c77d64 Mon Sep 17 00:00:00 2001
From: Hanumanth Hanumantharayappa <hhanuman at ah-hhanuman-l.dhcp.mathworks.com>
Date: Thu, 23 Oct 2025 14:39:16 -0400
Subject: [PATCH] [mlir][Tensor] Fix Tensor runtime verification pass to handle
 tensors with dimensions of size 0

---
 .../Transforms/RuntimeOpVerification.cpp      | 45 ++++++++++++++++++-
 .../extract_slice-runtime-verification.mlir   | 13 ++++++
 2 files changed, 57 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
index c031118606823..f53fe3a7f36cb 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
@@ -12,6 +12,8 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
 
@@ -158,7 +160,11 @@ struct ExtractSliceOpInterface
     // 0 <= offset + (size - 1) * stride < dim_size
     Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
     Value one = arith::ConstantIndexOp::create(builder, loc, 1);
-    for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+
+    for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
+      // Reset insertion point to before the operation for each dimension
+      builder.setInsertionPoint(extractSliceOp);
+
       Value offset = getValueOrCreateConstantIndexOp(
           builder, loc, extractSliceOp.getMixedOffsets()[i]);
       Value size = getValueOrCreateConstantIndexOp(
@@ -176,6 +182,42 @@ struct ExtractSliceOpInterface
                                                         std::to_string(i) +
                                                         " is out-of-bounds"));
 
+      // Only verify if size > 0
+      Value sizeIsNonZero = arith::CmpIOp::create(
+          builder, loc, arith::CmpIPredicate::sgt, size, zero);
+
+      /*
+       * Split the current block to create the below control flow structure:
+       *
+       * ^preCondBlock:
+       *   ... // offset check already done above
+       *   %size_nonzero = arith.cmpi sgt, %size, %zero
+       *   cf.cond_br %size_nonzero, ^sizeBoundsCheckBlock, ^afterCheckBlock
+       *
+       * ^sizeBoundsCheckBlock:
+       *   %last_pos = ... // compute offset + (size-1) * stride
+       *   %last_pos_ok = ... // last position bounds check
+       *   cf.assert %last_pos_ok, "extract_slice runs out-of-bounds"
+       *   cf.br ^afterCheckBlock
+       *
+       * ^afterCheckBlock:
+       *   tensor.extract_slice ... // the original operation
+       */
+      Block *preCondBlock = builder.getBlock();
+      Block *afterCheckBlock = preCondBlock->splitBlock(extractSliceOp);
+
+      // Create the block for conditional size bounds verification.
+      Block *sizeBoundsCheckBlock = builder.createBlock(
+          preCondBlock->getParent(), Region::iterator(afterCheckBlock));
+
+      // Terminate the pre-condition block with the conditional branch.
+      builder.setInsertionPointToEnd(preCondBlock);
+      cf::CondBranchOp::create(builder, loc, sizeIsNonZero,
+                               sizeBoundsCheckBlock, afterCheckBlock);
+
+      // Populate the size bounds check block with lastPos verification.
+      builder.setInsertionPointToStart(sizeBoundsCheckBlock);
+
       // Verify that slice does not run out-of-bounds.
       Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
       Value sizeMinusOneTimesStride =
@@ -189,6 +231,7 @@ struct ExtractSliceOpInterface
           generateErrorMessage(
               op, "extract_slice runs out-of-bounds along dimension " +
                       std::to_string(i)));
+      cf::BranchOp::create(builder, loc, afterCheckBlock);
     }
   }
 };
diff --git a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
index 0c7c4a6cb2d6f..558d5351d8dc7 100644
--- a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
@@ -34,6 +34,12 @@ func.func @extract_slice_dynamic_rank_reduce(%tensor: tensor<?x4xf32>, %offset:
     return
 }
 
+func.func @extract_slice_zero_size_dim(%arg0: tensor<10x4x1xf32>, %dim_0: index, %dim_1: index, %dim_2: index) {
+    tensor.extract_slice %arg0[0, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] : tensor<10x4x1xf32> to tensor<?x?x?xf32>
+    return
+}
+
+
 func.func @main() {
   %0 = arith.constant 0 : index
   %1 = arith.constant 1 : index
@@ -101,6 +107,13 @@ func.func @main() {
   // CHECK-NOT: ERROR: Runtime op verification failed
   func.call @extract_slice_dynamic_rank_reduce(%alloca_4_dyn, %0, %1, %0) : (tensor<?x4xf32>, index, index, index) -> ()
 
+  %alloca_10 = arith.constant dense<1.0> : tensor<10x4x1xf32>
+  
+  // CHECK-NOT: ERROR: Runtime op verification failed
+  %dim_0 = arith.constant 0 : index
+  %dim_1 = arith.constant 4 : index
+  %dim_2 = arith.constant 1 : index
+  func.call @extract_slice_zero_size_dim(%alloca_10, %dim_0, %dim_1, %dim_2) : (tensor<10x4x1xf32>, index, index, index) -> ()
 
   return
 }



More information about the Mlir-commits mailing list