[Mlir-commits] [mlir] [mlir][DialectUtils] Add helper for matching zero int/float values (PR #171293)

Matthias Springer llvmlistbot at llvm.org
Tue Dec 9 00:28:38 PST 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/171293

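Add convenience helpers `isZeroFloat` and `isZeroIntegerOrFloat`, similar to `isZeroInteger`, that match zero-valued float (and integer-or-float) values and attributes. Use `isZeroIntegerOrFloat` to replace the duplicated `m_Zero()` / `m_AnyZeroFloat()` checks in Linalg's FoldAddIntoDest and in SparseTensorRewriting.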


>From bfd97aa344dfd781f4548608c281a074ea61d2fc Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 9 Dec 2025 08:26:41 +0000
Subject: [PATCH] [mlir][DialectUtils] Add helper for matching zero int/float
 values

---
 mlir/include/mlir/Dialect/Utils/StaticValueUtils.h |  9 ++++++++-
 .../Dialect/Linalg/Transforms/FoldAddIntoDest.cpp  |  2 +-
 .../Transforms/SparseTensorRewriting.cpp           | 13 ++++---------
 mlir/lib/Dialect/Utils/StaticValueUtils.cpp        | 14 ++++++++++++++
 4 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 2e7f85cce4654..ba8a0304de9d3 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -24,9 +24,16 @@
 
 namespace mlir {
 
-/// Return true if `v` is an IntegerAttr with value `0`.
+/// Return "true" if `v` is an integer value/attribute with constant value `0`.
 bool isZeroInteger(OpFoldResult v);
 
+/// Return "true" if `v` is a float value/attribute with constant value `0.0`.
+/// Negative zero (`-0.0`) also counts as zero.
+bool isZeroFloat(OpFoldResult v);
+
+/// Return "true" if either `isZeroInteger(v)` or `isZeroFloat(v)` holds.
+bool isZeroIntegerOrFloat(OpFoldResult v);
+
 /// Return true if `v` is an IntegerAttr with value `1`.
 bool isOneInteger(OpFoldResult v);
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
index e940b0787043e..6f81702ee22c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FoldAddIntoDest.cpp
@@ -21,7 +21,7 @@ static bool isDefinedAsZero(Value val) {
 
   // Check whether val is a constant scalar / vector splat / tensor splat float
   // or integer zero.
-  if (matchPattern(val, m_AnyZeroFloat()) || matchPattern(val, m_Zero()))
+  if (isZeroIntegerOrFloat(val))
     return true;
 
   return TypeSwitch<Operation *, bool>(val.getDefiningOp())
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 79f4e7f67ab9d..24290bde62f49 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -39,11 +39,6 @@ using namespace mlir::sparse_tensor;
 // Helper methods for the actual rewriting rules.
 //===---------------------------------------------------------------------===//
 
-// Helper method to match any typed zero.
-static bool isZeroValue(Value val) {
-  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat());
-}
-
 // Helper to detect a sparse tensor type operand.
 static bool isSparseTensor(Value v) {
   auto enc = getSparseTensorEncoding(v.getType());
@@ -59,14 +54,14 @@ static bool isMaterializing(OpOperand *op, bool isZero) {
   if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
     Value copy = alloc.getCopy();
     if (isZero)
-      return copy && isZeroValue(copy);
+      return copy && isZeroIntegerOrFloat(copy);
     return !copy;
   }
   // Check for empty tensor materialization.
   if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
     return !isZero;
   // Last resort for zero alloc: the whole value is zero.
-  return isZero && isZeroValue(val);
+  return isZero && isZeroIntegerOrFloat(val);
 }
 
 // Helper to detect sampling operation.
@@ -114,10 +109,10 @@ static bool isZeroYield(GenericOp op) {
   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
   if (auto arg = dyn_cast<BlockArgument>(yieldOp.getOperand(0))) {
     if (arg.getOwner()->getParentOp() == op) {
-      return isZeroValue(op->getOperand(arg.getArgNumber()));
+      return isZeroIntegerOrFloat(op->getOperand(arg.getArgNumber()));
     }
   }
-  return isZeroValue(yieldOp.getOperand(0));
+  return isZeroIntegerOrFloat(yieldOp.getOperand(0));
 }
 
 /// Populates given sizes array from type (for static sizes) and from
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 8d3944f883963..089c551c1612b 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -19,6 +19,20 @@ namespace mlir {
 
 bool isZeroInteger(OpFoldResult v) { return isConstantIntValue(v, 0); }
 
+bool isZeroFloat(OpFoldResult v) {
+  // Use the "if_present" casts so that a null `v` yields "false" instead of
+  // tripping the non-null assertion in plain dyn_cast/cast.
+  if (auto val = llvm::dyn_cast_if_present<Value>(v))
+    return matchPattern(val, m_AnyZeroFloat());
+  auto floatAttr = llvm::dyn_cast_if_present<FloatAttr>(
+      llvm::dyn_cast_if_present<Attribute>(v));
+  return floatAttr && floatAttr.getValue().isZero();
+}
+
+bool isZeroIntegerOrFloat(OpFoldResult v) {
+  return isZeroInteger(v) || isZeroFloat(v);
+}
+
 bool isOneInteger(OpFoldResult v) { return isConstantIntValue(v, 1); }
 
 std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,



More information about the Mlir-commits mailing list