[Mlir-commits] [mlir] [mlir][linalg] Fix Linalg runtime verification pass to handle tensors with dimensions of size 0 (PR #163791)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 20 14:06:52 PDT 2025
https://github.com/Hanumanth04 updated https://github.com/llvm/llvm-project/pull/163791
>From 6ce2814832a7e6b05d8ebcc629e180be25f82431 Mon Sep 17 00:00:00 2001
From: Hanumanth Hanumantharayappa <hhanuman at ah-hhanuman-l.dhcp.mathworks.com>
Date: Wed, 15 Oct 2025 10:05:45 -0400
Subject: [PATCH 1/2] Fix Linalg runtime verification pass to handle empty
tensors
---
.../Transforms/RuntimeOpVerification.cpp | 28 +++++++++++++++++++
.../Linalg/CPU/runtime-verification.mlir | 11 ++++++++
2 files changed, 39 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index 15eb51a6dcab2..87f21edc94c09 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
@@ -43,6 +44,32 @@ struct StructuredOpInterface
auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
auto one = arith::ConstantIndexOp::create(builder, loc, 1);
+ Value iterationDomainIsNonDegenerate;
+ for (auto [start, end] : llvm::zip(starts, ends)) {
+ auto startValue = getValueOrCreateConstantIndexOp(builder, loc, start);
+ auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
+
+ // Loop Trip count > 0 iff start < end
+ Value dimensionHasNonZeroTripCount = builder.create<index::CmpOp>(
+ loc, index::IndexCmpPredicate::SLT, startValue, endValue);
+
+ if (!iterationDomainIsNonDegenerate) {
+ iterationDomainIsNonDegenerate = dimensionHasNonZeroTripCount;
+ } else {
+ // Iteration domain is non-degenerate iff all dimensions have loop trip count
+ // > 0
+ iterationDomainIsNonDegenerate = builder.create<arith::AndIOp>(
+ loc, iterationDomainIsNonDegenerate, dimensionHasNonZeroTripCount);
+ }
+ }
+
+ if (!iterationDomainIsNonDegenerate)
+ return;
+
+ auto ifOp = builder.create<scf::IfOp>(loc, iterationDomainIsNonDegenerate,
+ /*withElseRegion=*/false);
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
// Subtract one from the loop ends before composing with the indexing map
transform(ends, ends.begin(), [&](OpFoldResult end) {
auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
@@ -110,6 +137,7 @@ struct StructuredOpInterface
builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
}
}
+ builder.setInsertionPointAfter(ifOp);
}
};
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
index 9f4393efc87bf..0a2bdb9e8d68e 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
@@ -103,6 +103,11 @@ func.func @main() {
// CHECK: unexpected negative result on dimension #0 of input/output operand #0
func.call @reverse_from_3(%d5x) : (tensor<?xf32>) -> (tensor<?xf32>)
+ %c0x = arith.constant dense<0.0> : tensor<0xf32>
+ %d0x = tensor.cast %c0x : tensor<0xf32> to tensor<?xf32>
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @fill_empty_1d(%d0x) : (tensor<?xf32>) -> (tensor<?xf32>)
+
return
}
@@ -297,3 +302,9 @@ func.func @reverse_from_3(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
} -> tensor<?xf32>
return %result : tensor<?xf32>
}
+
+func.func @fill_empty_1d(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+ %c0 = arith.constant 0.0 : f32
+ %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?xf32>) -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
>From 43f73d73db00f2049424205b0a9ac39ddb3865bd Mon Sep 17 00:00:00 2001
From: Hanumanth Hanumantharayappa <hhanuman at ah-hhanuman-l.dhcp.mathworks.com>
Date: Thu, 16 Oct 2025 10:05:02 -0400
Subject: [PATCH 2/2] Revert clang-format on .mlir file
---
.../Linalg/Transforms/RuntimeOpVerification.cpp | 12 ++++++------
.../Dialect/Linalg/CPU/runtime-verification.mlir | 14 +++++++++++++-
2 files changed, 19 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index 87f21edc94c09..181b4846835c0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -56,19 +56,19 @@ struct StructuredOpInterface
if (!iterationDomainIsNonDegenerate) {
iterationDomainIsNonDegenerate = dimensionHasNonZeroTripCount;
} else {
- // Iteration domain is non-degenerate iff all dimensions have loop trip count
- // > 0
+ // Iteration domain is non-degenerate iff all dimensions have loop trip
+ // count > 0
iterationDomainIsNonDegenerate = builder.create<arith::AndIOp>(
loc, iterationDomainIsNonDegenerate, dimensionHasNonZeroTripCount);
}
}
-
+
if (!iterationDomainIsNonDegenerate)
return;
- auto ifOp = builder.create<scf::IfOp>(loc, iterationDomainIsNonDegenerate,
- /*withElseRegion=*/false);
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ auto ifOp = builder.create<scf::IfOp>(loc, iterationDomainIsNonDegenerate,
+ /*withElseRegion=*/false);
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
// Subtract one from the loop ends before composing with the indexing map
transform(ends, ends.begin(), [&](OpFoldResult end) {
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
index 0a2bdb9e8d68e..127ab70cb4539 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir
@@ -103,11 +103,17 @@ func.func @main() {
// CHECK: unexpected negative result on dimension #0 of input/output operand #0
func.call @reverse_from_3(%d5x) : (tensor<?xf32>) -> (tensor<?xf32>)
- %c0x = arith.constant dense<0.0> : tensor<0xf32>
+ %c0x = arith.constant dense<1.0> : tensor<0xf32>
%d0x = tensor.cast %c0x : tensor<0xf32> to tensor<?xf32>
// CHECK-NOT: ERROR: Runtime op verification failed
func.call @fill_empty_1d(%d0x) : (tensor<?xf32>) -> (tensor<?xf32>)
+ %c0x5 = arith.constant dense<0.0> : tensor<0x5xf32>
+ %d0x5 = tensor.cast %c0x5 : tensor<0x5xf32> to tensor<?x?xf32>
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ func.call @fill_empty_2d(%d0x5) : (tensor<?x?xf32>) -> (tensor<?x?xf32>)
+
return
}
@@ -308,3 +314,9 @@ func.func @fill_empty_1d(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
%0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
+
+func.func @fill_empty_2d(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+ %c0 = arith.constant 0.0 : f32
+ %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
More information about the Mlir-commits
mailing list