[Mlir-commits] [mlir] [mlir][affine] Fix crash in AffineParallelLowering for unsupported reductions (PR #186189)

Mehdi Amini llvmlistbot at llvm.org
Thu Mar 12 11:56:48 PDT 2026


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/186189

>From afc1beb1b99cb8b67f76bfcdf92e24b5c26de2cd Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Mon, 9 Mar 2026 04:36:36 -0700
Subject: [PATCH 1/2] [mlir][affine] Fix crash in AffineParallelLowering for
 unsupported reductions

When lowering affine.parallel with a reduction kind that has no identity
value (e.g. "assign"), getIdentityValueAttr() returns nullptr. The caller
getIdentityValue() then passed this null TypedAttr to
arith::ConstantOp::create(), triggering an LLVM_ERROR crash:
  "Failed to infer result type(s): arith.constant"

Three issues are fixed:

1. getIdentityValue() now returns an empty Value{} when
   getIdentityValueAttr() returns nullptr, instead of unconditionally
   forwarding the null to ConstantOp::create().

2. AffineParallelLowering::matchAndRewrite() checks the returned Value
   and calls notifyMatchFailure() when the identity value is unavailable,
   allowing the pattern to fail gracefully rather than crashing.

3. AffineYieldOpLowering::matchAndRewrite() also returns failure when
   the parent op is affine.parallel (not just scf.parallel), preventing
   affine.yield from being prematurely lowered to scf.yield when the
   surrounding affine.parallel lowering failed.

Fixes #185250

Assisted-by: Claude Code
---
 .../AffineToStandard/AffineToStandard.cpp       | 10 +++++++---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp          |  2 ++
 .../AffineToStandard/lower-affine-invalid.mlir  | 17 +++++++++++++++++
 3 files changed, 26 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Conversion/AffineToStandard/lower-affine-invalid.mlir

diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index 3b148f9021666..826d7547716e4 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -134,7 +134,7 @@ class AffineYieldOpLowering : public OpRewritePattern<AffineYieldOp> {
 
   LogicalResult matchAndRewrite(AffineYieldOp op,
                                 PatternRewriter &rewriter) const override {
-    if (isa<scf::ParallelOp>(op->getParentOp())) {
+    if (isa<scf::ParallelOp, AffineParallelOp>(op->getParentOp())) {
       // Terminator is rewritten as part of the "affine.parallel" lowering
       // pattern.
       return failure();
@@ -230,8 +230,12 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
               static_cast<uint64_t>(cast<IntegerAttr>(reduction).getInt()));
       assert(reductionOp && "Reduction operation cannot be of None Type");
       arith::AtomicRMWKind reductionOpValue = *reductionOp;
-      identityVals.push_back(
-          arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
+      Value identityVal =
+          arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc);
+      if (!identityVal)
+        return rewriter.notifyMatchFailure(
+            op, "unsupported reduction kind for identity value");
+      identityVals.push_back(identityVal);
     }
     parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple,
                                     upperBoundTuple, steps, identityVals,
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 6f368604df65a..d3e2e77d10341 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2801,6 +2801,8 @@ Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
                                     bool useOnlyFiniteValue) {
   auto attr =
       getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
+  if (!attr)
+    return {};
   return arith::ConstantOp::create(builder, loc, attr);
 }
 
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine-invalid.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine-invalid.mlir
new file mode 100644
index 0000000000000..a357b27a2430e
--- /dev/null
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine-invalid.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt --lower-affine %s 2>&1 | FileCheck %s
+
+// Test that affine.parallel with an unsupported reduction kind ("assign")
+// does not crash but emits a proper error message. Previously,
+// getIdentityValue would be called with a null TypedAttr and crash inside
+// arith::ConstantOp::build with "Failed to infer result type(s)".
+
+// CHECK: Reduction operation type not supported
+// CHECK-NOT: Failed to infer result type
+
+func.func @affine_parallel_assign_reduction_no_crash(%n: index) -> i32 {
+  %0 = affine.parallel (%i) = (0) to (%n) reduce ("assign") -> i32 {
+    %c0 = arith.constant 0 : i32
+    affine.yield %c0 : i32
+  }
+  return %0 : i32
+}

>From 54d2fd5743a1f914f2a66e422662ec20866193ee Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 12 Mar 2026 19:56:39 +0100
Subject: [PATCH 2/2] Update mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>
---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d3e2e77d10341..a25c0ed23456a 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2799,11 +2799,10 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
                                     OpBuilder &builder, Location loc,
                                     bool useOnlyFiniteValue) {
-  auto attr =
-      getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
-  if (!attr)
-    return {};
-  return arith::ConstantOp::create(builder, loc, attr);
+  if (auto attr =
+getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue))
+    return arith::ConstantOp::create(builder, loc, attr);
+  return {};
 }
 
 /// Return the value obtained by applying the reduction operation kind



More information about the Mlir-commits mailing list