[Mlir-commits] [mlir] [mlir][tensor] Fix runtime verification for `tensor.extract_slice` when size dimension value is 0 (PR #164878)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 23 14:28:16 PDT 2025


https://github.com/Hanumanth04 updated https://github.com/llvm/llvm-project/pull/164878

>From 5715e8f0f2f77ce798b61353a82901ab66c77d64 Mon Sep 17 00:00:00 2001
From: Hanumanth Hanumantharayappa <hhanuman at ah-hhanuman-l.dhcp.mathworks.com>
Date: Thu, 23 Oct 2025 14:39:16 -0400
Subject: [PATCH 1/3] [mlir][Tensor] Fix Tensor runtime verification pass to
 handle extract_slice sizes of 0

---
 .../Transforms/RuntimeOpVerification.cpp      | 45 ++++++++++++++++++-
 .../extract_slice-runtime-verification.mlir   | 13 ++++++
 2 files changed, 57 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
index c031118606823..f53fe3a7f36cb 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
@@ -12,6 +12,8 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
 
@@ -158,7 +160,11 @@ struct ExtractSliceOpInterface
     // 0 <= offset + (size - 1) * stride < dim_size
     Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
     Value one = arith::ConstantIndexOp::create(builder, loc, 1);
-    for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+
+    for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
+      // Reset insertion point to before the operation for each dimension
+      builder.setInsertionPoint(extractSliceOp);
+
       Value offset = getValueOrCreateConstantIndexOp(
           builder, loc, extractSliceOp.getMixedOffsets()[i]);
       Value size = getValueOrCreateConstantIndexOp(
@@ -176,6 +182,42 @@ struct ExtractSliceOpInterface
                                                         std::to_string(i) +
                                                         " is out-of-bounds"));
 
+      // Only verify if size > 0
+      Value sizeIsNonZero = arith::CmpIOp::create(
+          builder, loc, arith::CmpIPredicate::sgt, size, zero);
+
+      /*
+       * Split the current block to create the below control flow structure:
+       *
+       * ^preCondBlock:
+       *   ... // offset check already done above
+       *   %size_nonzero = arith.cmpi sgt, %size, %zero
+       *   cf.cond_br %size_nonzero, ^sizeBoundsCheckBlock, ^afterCheckBlock
+       *
+       * ^sizeBoundsCheckBlock:
+       *   %last_pos = ... // compute offset + (size-1) * stride
+       *   %last_pos_ok = ... // last position bounds check
+       *   cf.assert %last_pos_ok, "extract_slice runs out-of-bounds"
+       *   cf.br ^afterCheckBlock
+       *
+       * ^afterCheckBlock:
+       *   tensor.extract_slice ... // the original operation
+       */
+      Block *preCondBlock = builder.getBlock();
+      Block *afterCheckBlock = preCondBlock->splitBlock(extractSliceOp);
+
+      // Create the block for conditional size bounds verification.
+      Block *sizeBoundsCheckBlock = builder.createBlock(
+          preCondBlock->getParent(), Region::iterator(afterCheckBlock));
+
+      // Terminate the pre-condition block with the conditional branch.
+      builder.setInsertionPointToEnd(preCondBlock);
+      cf::CondBranchOp::create(builder, loc, sizeIsNonZero,
+                               sizeBoundsCheckBlock, afterCheckBlock);
+
+      // Populate the size bounds check block with lastPos verification.
+      builder.setInsertionPointToStart(sizeBoundsCheckBlock);
+
       // Verify that slice does not run out-of-bounds.
       Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
       Value sizeMinusOneTimesStride =
@@ -189,6 +231,7 @@ struct ExtractSliceOpInterface
           generateErrorMessage(
               op, "extract_slice runs out-of-bounds along dimension " +
                       std::to_string(i)));
+      cf::BranchOp::create(builder, loc, afterCheckBlock);
     }
   }
 };
diff --git a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
index 0c7c4a6cb2d6f..558d5351d8dc7 100644
--- a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
@@ -34,6 +34,12 @@ func.func @extract_slice_dynamic_rank_reduce(%tensor: tensor<?x4xf32>, %offset:
     return
 }
 
+func.func @extract_slice_zero_size_dim(%arg0: tensor<10x4x1xf32>, %dim_0: index, %dim_1: index, %dim_2: index) {
+    tensor.extract_slice %arg0[0, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] : tensor<10x4x1xf32> to tensor<?x?x?xf32>
+    return
+}
+
+
 func.func @main() {
   %0 = arith.constant 0 : index
   %1 = arith.constant 1 : index
@@ -101,6 +107,13 @@ func.func @main() {
   // CHECK-NOT: ERROR: Runtime op verification failed
   func.call @extract_slice_dynamic_rank_reduce(%alloca_4_dyn, %0, %1, %0) : (tensor<?x4xf32>, index, index, index) -> ()
 
+  %alloca_10 = arith.constant dense<1.0> : tensor<10x4x1xf32>
+  
+  // CHECK-NOT: ERROR: Runtime op verification failed
+  %dim_0 = arith.constant 0 : index
+  %dim_1 = arith.constant 4 : index
+  %dim_2 = arith.constant 1 : index
+  func.call @extract_slice_zero_size_dim(%alloca_10, %dim_0, %dim_1, %dim_2) : (tensor<10x4x1xf32>, index, index, index) -> ()
 
   return
 }

>From a74491cd5a3f31df2174cdf79efa40edac418c3d Mon Sep 17 00:00:00 2001
From: Hanumanth Hanumantharayappa <hhanuman at ah-hhanuman-l.dhcp.mathworks.com>
Date: Thu, 23 Oct 2025 15:40:10 -0400
Subject: [PATCH 2/3] Use more descriptive name for test constant

---
 .../Dialect/Tensor/extract_slice-runtime-verification.mlir    | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
index 558d5351d8dc7..a77fa310a3699 100644
--- a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
@@ -107,13 +107,13 @@ func.func @main() {
   // CHECK-NOT: ERROR: Runtime op verification failed
   func.call @extract_slice_dynamic_rank_reduce(%alloca_4_dyn, %0, %1, %0) : (tensor<?x4xf32>, index, index, index) -> ()
 
-  %alloca_10 = arith.constant dense<1.0> : tensor<10x4x1xf32>
+  %cst10x4x1xf32 = arith.constant dense<1.0> : tensor<10x4x1xf32>
   
   // CHECK-NOT: ERROR: Runtime op verification failed
   %dim_0 = arith.constant 0 : index
   %dim_1 = arith.constant 4 : index
   %dim_2 = arith.constant 1 : index
-  func.call @extract_slice_zero_size_dim(%alloca_10, %dim_0, %dim_1, %dim_2) : (tensor<10x4x1xf32>, index, index, index) -> ()
+  func.call @extract_slice_zero_size_dim(%cst10x4x1xf32, %dim_0, %dim_1, %dim_2) : (tensor<10x4x1xf32>, index, index, index) -> ()
 
   return
 }

>From 02e5a91fe9f3cdd580b1aa16fbb856de3fb78a09 Mon Sep 17 00:00:00 2001
From: Hanumanth Hanumantharayappa <hhanuman at ah-hhanuman-l.dhcp.mathworks.com>
Date: Thu, 23 Oct 2025 17:27:05 -0400
Subject: [PATCH 3/3] Remove redundant header includes

---
 mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
index f53fe3a7f36cb..346aa6b4eb73d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
@@ -12,8 +12,6 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
-#include "mlir/Dialect/Index/IR/IndexOps.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
 



More information about the Mlir-commits mailing list