[Mlir-commits] [mlir] [mlir][linalg] Fix Linalg runtime verification pass to handle tensors with dimensions of size 0 (PR #163791)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 20 14:06:52 PDT 2025


https://github.com/Hanumanth04 updated https://github.com/llvm/llvm-project/pull/163791

>From 6ce2814832a7e6b05d8ebcc629e180be25f82431 Mon Sep 17 00:00:00 2001
From: Hanumanth Hanumantharayappa <hhanuman at ah-hhanuman-l.dhcp.mathworks.com>
Date: Wed, 15 Oct 2025 10:05:45 -0400
Subject: [PATCH 1/2] Fix Linalg runtime verification pass to handle empty
 tensors

---
 .../Transforms/RuntimeOpVerification.cpp      | 28 +++++++++++++++++++
 .../Linalg/CPU/runtime-verification.mlir      | 11 ++++++++
 2 files changed, 39 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index 15eb51a6dcab2..87f21edc94c09 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
 
@@ -43,6 +44,32 @@ struct StructuredOpInterface
     auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
     auto one = arith::ConstantIndexOp::create(builder, loc, 1);
 
+    Value iterationDomainIsNonDegenerate;
+    for (auto [start, end] : llvm::zip(starts, ends)) {
+      auto startValue = getValueOrCreateConstantIndexOp(builder, loc, start);
+      auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
+
+      // Loop Trip count > 0 iff start < end
+      Value dimensionHasNonZeroTripCount = builder.create<index::CmpOp>(
+          loc, index::IndexCmpPredicate::SLT, startValue, endValue);
+
+      if (!iterationDomainIsNonDegenerate) {
+        iterationDomainIsNonDegenerate = dimensionHasNonZeroTripCount;
+      } else {
+        // Iteration domain is non-degenerate iff all dimensions have loop trip count
+        // > 0
+        iterationDomainIsNonDegenerate = builder.create<arith::AndIOp>(
+            loc, iterationDomainIsNonDegenerate, dimensionHasNonZeroTripCount);
+      }
+    }
+    
+    if (!iterationDomainIsNonDegenerate)
+      return;
+
+     auto ifOp = builder.create<scf::IfOp>(loc, iterationDomainIsNonDegenerate,
+                                        /*withElseRegion=*/false);
+     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
     // Subtract one from the loop ends before composing with the indexing map
     transform(ends, ends.begin(), [&](OpFoldResult end) {
       auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
@@ -110,6 +137,7 @@ struct StructuredOpInterface
         builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
       }
     }
+    builder.setInsertionPointAfter(ifOp);
   }
 };
 
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
index 9f4393efc87bf..0a2bdb9e8d68e 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
@@ -103,6 +103,11 @@ func.func @main() {
   // CHECK: unexpected negative result on dimension #0 of input/output operand #0
   func.call @reverse_from_3(%d5x) : (tensor<?xf32>) -> (tensor<?xf32>)
 
+  %c0x = arith.constant dense<0.0> : tensor<0xf32>
+  %d0x = tensor.cast %c0x : tensor<0xf32> to tensor<?xf32>
+  // CHECK-NOT: ERROR: Runtime op verification failed
+  func.call @fill_empty_1d(%d0x) : (tensor<?xf32>) -> (tensor<?xf32>)
+
   return
 }
 
@@ -297,3 +302,9 @@ func.func @reverse_from_3(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
   } -> tensor<?xf32>
   return %result : tensor<?xf32>
 }
+
+func.func @fill_empty_1d(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0.0 : f32
+  %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?xf32>) -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}

>From 43f73d73db00f2049424205b0a9ac39ddb3865bd Mon Sep 17 00:00:00 2001
From: Hanumanth Hanumantharayappa <hhanuman at ah-hhanuman-l.dhcp.mathworks.com>
Date: Thu, 16 Oct 2025 10:05:02 -0400
Subject: [PATCH 2/2] Apply clang-format to RuntimeOpVerification.cpp and add a
 2-D empty tensor test

---
 .../Linalg/Transforms/RuntimeOpVerification.cpp    | 12 ++++++------
 .../Dialect/Linalg/CPU/runtime-verification.mlir   | 14 +++++++++++++-
 2 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index 87f21edc94c09..181b4846835c0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -56,19 +56,19 @@ struct StructuredOpInterface
       if (!iterationDomainIsNonDegenerate) {
         iterationDomainIsNonDegenerate = dimensionHasNonZeroTripCount;
       } else {
-        // Iteration domain is non-degenerate iff all dimensions have loop trip count
-        // > 0
+        // Iteration domain is non-degenerate iff all dimensions have loop trip
+        // count > 0
         iterationDomainIsNonDegenerate = builder.create<arith::AndIOp>(
             loc, iterationDomainIsNonDegenerate, dimensionHasNonZeroTripCount);
       }
     }
-    
+
     if (!iterationDomainIsNonDegenerate)
       return;
 
-     auto ifOp = builder.create<scf::IfOp>(loc, iterationDomainIsNonDegenerate,
-                                        /*withElseRegion=*/false);
-     builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    auto ifOp = builder.create<scf::IfOp>(loc, iterationDomainIsNonDegenerate,
+                                          /*withElseRegion=*/false);
+    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
 
     // Subtract one from the loop ends before composing with the indexing map
     transform(ends, ends.begin(), [&](OpFoldResult end) {
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
index 0a2bdb9e8d68e..127ab70cb4539 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
@@ -103,11 +103,17 @@ func.func @main() {
   // CHECK: unexpected negative result on dimension #0 of input/output operand #0
   func.call @reverse_from_3(%d5x) : (tensor<?xf32>) -> (tensor<?xf32>)
 
-  %c0x = arith.constant dense<0.0> : tensor<0xf32>
+  %c0x = arith.constant dense<1.0> : tensor<0xf32>
   %d0x = tensor.cast %c0x : tensor<0xf32> to tensor<?xf32>
   // CHECK-NOT: ERROR: Runtime op verification failed
   func.call @fill_empty_1d(%d0x) : (tensor<?xf32>) -> (tensor<?xf32>)
 
+  %c0x5 = arith.constant dense<0.0> : tensor<0x5xf32>
+  %d0x5 = tensor.cast %c0x5 : tensor<0x5xf32> to tensor<?x?xf32>
+
+  // CHECK-NOT: ERROR: Runtime op verification failed
+  func.call @fill_empty_2d(%d0x5) : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
   return
 }
 
@@ -308,3 +314,9 @@ func.func @fill_empty_1d(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
   %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?xf32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
+
+func.func @fill_empty_2d(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %c0 = arith.constant 0.0 : f32
+  %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}



More information about the Mlir-commits mailing list