[llvm-branch-commits] [mlir] MLIR bug fixes for LLVM 21.x release (PR #154587)

Mehdi Amini via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Aug 20 10:57:37 PDT 2025


https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/154587

I went through the recent bug fixes in MLIR and cherry-picked the ones that seem good to have for the 21.x branch.

>From 877642f0b5018d0922dbbb72d95f1eefffbe09fd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= <tlongeri at google.com>
Date: Wed, 16 Jul 2025 00:52:35 -0700
Subject: [PATCH 01/22] [MLIR][Vector] Fix bug in ExtractStrideSlicesOp
 canonicalization (#147591)

The pattern would produce an invalid slice when some dimensions were
both sliced and broadcast.
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 39 +++++++++++++---------
 mlir/test/Dialect/Vector/canonicalize.mlir | 15 +++++++++
 2 files changed, 38 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8a08a157b25d7..7d615bfc12984 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4237,28 +4237,35 @@ class StridedSliceBroadcast final
     auto dstVecType = llvm::cast<VectorType>(op.getType());
     unsigned dstRank = dstVecType.getRank();
     unsigned rankDiff = dstRank - srcRank;
-    // Check if the most inner dimensions of the source of the broadcast are the
-    // same as the destination of the extract. If this is the case we can just
-    // use a broadcast as the original dimensions are untouched.
-    bool lowerDimMatch = true;
+    // Source dimensions can be broadcasted (1 -> n with n > 1) or sliced
+    // (n -> m with n > m). If they are originally both broadcasted *and*
+    // sliced, this can be simplified to just broadcasting.
+    bool needsSlice = false;
     for (unsigned i = 0; i < srcRank; i++) {
-      if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
-        lowerDimMatch = false;
+      if (srcVecType.getDimSize(i) != 1 &&
+          srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
+        needsSlice = true;
         break;
       }
     }
     Value source = broadcast.getSource();
-    // If the inner dimensions don't match, it means we need to extract from the
-    // source of the orignal broadcast and then broadcast the extracted value.
-    // We also need to handle degenerated cases where the source is effectively
-    // just a single scalar.
-    bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
-    if (!lowerDimMatch && !isScalarSrc) {
+    if (needsSlice) {
+      SmallVector<int64_t> offsets =
+          getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff);
+      SmallVector<int64_t> sizes =
+          getI64SubArray(op.getSizes(), /*dropFront=*/rankDiff);
+      for (unsigned i = 0; i < srcRank; i++) {
+        if (srcVecType.getDimSize(i) == 1) {
+          // In case this dimension was broadcasted *and* sliced, the offset
+          // and size need to be updated now that there is no broadcast before
+          // the slice.
+          offsets[i] = 0;
+          sizes[i] = 1;
+        }
+      }
       source = rewriter.create<ExtractStridedSliceOp>(
-          op->getLoc(), source,
-          getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
-          getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
-          getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
+          op->getLoc(), source, offsets, sizes,
+          getI64SubArray(op.getStrides(), /*dropFront=*/rankDiff));
     }
     rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
     return success();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 12187dd18012b..ea2343efd246e 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1379,6 +1379,21 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
 
 // -----
 
+// Check the case where the same dimension is both broadcasted and sliced 
+// CHECK-LABEL: func @extract_strided_broadcast5
+//  CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>)
+//       CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32>
+//       CHECK: return %[[V]]
+func.func @extract_strided_broadcast5(%arg0: vector<2x1xf32>) -> vector<2x4xf32> {
+ %0 = vector.broadcast %arg0 : vector<2x1xf32> to vector<2x8xf32>
+ %1 = vector.extract_strided_slice %0
+      {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]}
+      : vector<2x8xf32> to vector<2x4xf32>
+  return %1 : vector<2x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: consecutive_shape_cast
 //       CHECK:   %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
 //  CHECK-NEXT:   return %[[C]] : vector<4x4xf16>

>From e393d8f41d078080fbead918946c33d821afc330 Mon Sep 17 00:00:00 2001
From: Artemiy Bulavin <artemiyb at graphcore.ai>
Date: Wed, 16 Jul 2025 11:11:38 +0100
Subject: [PATCH 02/22] [MLIR] Fix use-after-frees when accessing DistinctAttr
 storage (#148666)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This PR fixes a use-after-free error that happens when `DistinctAttr`
instances are created within a `PassManager` running with crash recovery
enabled. The root cause is that `DistinctAttr` storage is allocated in a
thread_local allocator, which is destroyed when the crash recovery
thread joins, invalidating the storage.

Moreover, even without crash recovery, disabling multithreading on
the context will destroy the context's thread pool, and in turn delete
the thread_local storage. This means a call to
`ctx->disableMultithreading()` breaks the IR.

This PR replaces the thread local allocator with a synchronised
allocator that's shared between threads. This persists the lifetime of
allocated DistinctAttr storage instances to the lifetime of the context.

### Problem Details:

The `DistinctAttributeAllocator` uses a
`ThreadLocalCache<BumpPtrAllocator>` for lock-free allocation of
`DistinctAttr` storage in a multithreaded context. The issue occurs when
a `PassManager` is run with crash recovery (`runWithCrashRecovery`): the
pass pipeline is executed on a temporary thread spawned by
`llvm::CrashRecoveryContext`. Any `DistinctAttr`s created during this
execution have their storage allocated in the thread_local cache of this
temporary thread. When the thread joins, the thread_local storage is
destroyed, freeing the `DistinctAttr`s' memory. If this attribute is
accessed later, e.g. when printing, it results in a use-after-free.

As mentioned previously, this is also seen after creating some
`DistinctAttr`s and then calling `ctx->disableMultithreading()`.

### Solution

`DistinctAttributeAllocator` now uses a synchronised, shared allocator
instead of one wrapped in a `ThreadLocalCache`. The latter is what
stored the allocator in transient thread_local storage.

### Testing:

A C++ unit test has been added to validate this fix. (I was previously
reproducing this failure with `mlir-opt` but I can no longer do so and I
am unsure why.)

-----

Note: This is a 2nd attempt at my previous PR
https://github.com/llvm/llvm-project/pull/128566 that was reverted in
https://github.com/llvm/llvm-project/pull/133000. I believe I've
addressed the TSAN and race condition concerns.
---
 mlir/lib/IR/AttributeDetail.h                 | 29 ++++++------
 mlir/lib/IR/MLIRContext.cpp                   |  1 +
 mlir/lib/Pass/PassCrashRecovery.cpp           |  9 +++-
 mlir/unittests/IR/CMakeLists.txt              |  1 +
 .../IR/DistinctAttributeAllocatorTest.cpp     | 45 +++++++++++++++++++
 5 files changed, 69 insertions(+), 16 deletions(-)
 create mode 100644 mlir/unittests/IR/DistinctAttributeAllocatorTest.cpp

diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 26d40ac3a38f6..cb9d21bf3e611 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -19,11 +19,9 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/MLIRContext.h"
-#include "mlir/Support/StorageUniquer.h"
-#include "mlir/Support/ThreadLocalCache.h"
 #include "llvm/ADT/APFloat.h"
-#include "llvm/ADT/PointerIntPair.h"
-#include "llvm/Support/TrailingObjects.h"
+#include "llvm/Support/Allocator.h"
+#include <mutex>
 
 namespace mlir {
 namespace detail {
@@ -396,27 +394,30 @@ class DistinctAttributeUniquer {
                                               Attribute referencedAttr);
 };
 
-/// An allocator for distinct attribute storage instances. It uses thread local
-/// bump pointer allocators stored in a thread local cache to ensure the storage
-/// is freed after the destruction of the distinct attribute allocator.
-class DistinctAttributeAllocator {
+/// An allocator for distinct attribute storage instances. Uses a synchronized
+/// BumpPtrAllocator to ensure thread-safety. The allocated storage is deleted
+/// when the DistinctAttributeAllocator is destroyed.
+class DistinctAttributeAllocator final {
 public:
   DistinctAttributeAllocator() = default;
-
   DistinctAttributeAllocator(DistinctAttributeAllocator &&) = delete;
   DistinctAttributeAllocator(const DistinctAttributeAllocator &) = delete;
   DistinctAttributeAllocator &
   operator=(const DistinctAttributeAllocator &) = delete;
 
-  /// Allocates a distinct attribute storage using a thread local bump pointer
-  /// allocator to enable synchronization free parallel allocations.
   DistinctAttrStorage *allocate(Attribute referencedAttr) {
-    return new (allocatorCache.get().Allocate<DistinctAttrStorage>())
+    std::scoped_lock<std::mutex> guard(allocatorMutex);
+    return new (allocator.Allocate<DistinctAttrStorage>())
         DistinctAttrStorage(referencedAttr);
-  }
+  };
 
 private:
-  ThreadLocalCache<llvm::BumpPtrAllocator> allocatorCache;
+  /// Used to allocate distict attribute storages. The managed memory is freed
+  /// automatically when the allocator instance is destroyed.
+  llvm::BumpPtrAllocator allocator;
+
+  /// Used to lock access to the allocator.
+  std::mutex allocatorMutex;
 };
 } // namespace detail
 } // namespace mlir
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 716d9c85a377d..06ec1c85fb4d5 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -31,6 +31,7 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/ManagedStatic.h"
 #include "llvm/Support/Mutex.h"
 #include "llvm/Support/RWMutex.h"
 #include "llvm/Support/ThreadPool.h"
diff --git a/mlir/lib/Pass/PassCrashRecovery.cpp b/mlir/lib/Pass/PassCrashRecovery.cpp
index b048ff9462392..c6fb1d737d508 100644
--- a/mlir/lib/Pass/PassCrashRecovery.cpp
+++ b/mlir/lib/Pass/PassCrashRecovery.cpp
@@ -414,14 +414,19 @@ struct FileReproducerStream : public mlir::ReproducerStream {
 
 LogicalResult PassManager::runWithCrashRecovery(Operation *op,
                                                 AnalysisManager am) {
+  const bool threadingEnabled = getContext()->isMultithreadingEnabled();
   crashReproGenerator->initialize(getPasses(), op, verifyPasses);
 
   // Safely invoke the passes within a recovery context.
   LogicalResult passManagerResult = failure();
   llvm::CrashRecoveryContext recoveryContext;
-  recoveryContext.RunSafelyOnThread(
-      [&] { passManagerResult = runPasses(op, am); });
+  const auto runPassesFn = [&] { passManagerResult = runPasses(op, am); };
+  if (threadingEnabled)
+    recoveryContext.RunSafelyOnThread(runPassesFn);
+  else
+    recoveryContext.RunSafely(runPassesFn);
   crashReproGenerator->finalize(op, passManagerResult);
+
   return passManagerResult;
 }
 
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
index d22afb3003e76..a46e64718dab9 100644
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_unittest(MLIRIRTests
   AttrTypeReplacerTest.cpp
   Diagnostic.cpp
   DialectTest.cpp
+  DistinctAttributeAllocatorTest.cpp
   InterfaceTest.cpp
   IRMapping.cpp
   InterfaceAttachmentTest.cpp
diff --git a/mlir/unittests/IR/DistinctAttributeAllocatorTest.cpp b/mlir/unittests/IR/DistinctAttributeAllocatorTest.cpp
new file mode 100644
index 0000000000000..99067d09f7bed
--- /dev/null
+++ b/mlir/unittests/IR/DistinctAttributeAllocatorTest.cpp
@@ -0,0 +1,45 @@
+//=== DistinctAttributeAllocatorTest.cpp - DistinctAttr storage alloc test ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "gtest/gtest.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "llvm/Support/CrashRecoveryContext.h"
+#include <thread>
+
+using namespace mlir;
+
+//
+// Test that a DistinctAttr that is created on a separate thread does
+// not have its storage deleted when the thread joins.
+//
+TEST(DistinctAttributeAllocatorTest, TestAttributeWellFormedAfterThreadJoin) {
+  MLIRContext ctx;
+  OpBuilder builder(&ctx);
+  DistinctAttr attr;
+
+  std::thread t([&ctx, &attr]() {
+    attr = DistinctAttr::create(UnitAttr::get(&ctx));
+    ASSERT_TRUE(attr);
+  });
+  t.join();
+
+  // If the attribute storage got deleted after the thread joins (which we don't
+  // want) then trying to access it triggers an assert in Debug mode, and a
+  // crash otherwise. Run this in a CrashRecoveryContext to avoid bringing down
+  // the whole test suite if this test fails. Additionally, MSAN and/or TSAN
+  // should raise failures here if the attribute storage was deleted.
+  llvm::CrashRecoveryContext crc;
+  EXPECT_TRUE(crc.RunSafely([attr]() { (void)attr.getAbstractAttribute(); }));
+  EXPECT_TRUE(
+      crc.RunSafely([attr]() { (void)*cast<Attribute>(attr).getImpl(); }));
+
+  ASSERT_TRUE(attr);
+}

>From e8c2aeb40f97db7c17b547c98e233b618c16d35c Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Fri, 18 Jul 2025 09:28:29 +0800
Subject: [PATCH 03/22] [mlir][mesh] Add null check for dyn_cast to prevent
 crash (#149266)

This PR adds a null check on the `dyn_cast` result before it is used,
to prevent a crash, and uses `isa` instead of `dyn_cast` to make the
code cleaner. Fixes #148619.
---
 .../mlir/Dialect/Mesh/Transforms/Simplifications.h   |  8 ++++----
 mlir/test/Dialect/Mesh/simplifications.mlir          | 12 ++++++++++++
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index 3f1041cb25103..243dbf081b999 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -62,9 +62,11 @@ void populateAllReduceEndomorphismSimplificationPatterns(
   auto isEndomorphismOp = [reduction](Operation *op,
                                       std::optional<Operation *> referenceOp) {
     auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
+    if (!allReduceOp)
+      return false;
     auto inType = cast<ShapedType>(allReduceOp.getInput().getType());
     auto outType = cast<ShapedType>(allReduceOp.getResult().getType());
-    if (!allReduceOp || inType.getElementType() != outType.getElementType() ||
+    if (inType.getElementType() != outType.getElementType() ||
         allReduceOp.getReduction() != reduction) {
       return false;
     }
@@ -87,9 +89,7 @@ void populateAllReduceEndomorphismSimplificationPatterns(
     return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
            inType.getElementType() == refType.getElementType();
   };
-  auto isAlgebraicOp = [](Operation *op) {
-    return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
-  };
+  auto isAlgebraicOp = [](Operation *op) { return isa<AlgebraicOp>(op); };
 
   using ConcreteEndomorphismSimplification = EndomorphismSimplification<
       std::decay_t<decltype(getEndomorphismOpOperand)>,
diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Mesh/simplifications.mlir
index 2540fbf9510c4..e955f4c134259 100644
--- a/mlir/test/Dialect/Mesh/simplifications.mlir
+++ b/mlir/test/Dialect/Mesh/simplifications.mlir
@@ -165,3 +165,15 @@ func.func @all_reduce_arith_minsi_endomorphism(
   // CHECK: return %[[ALL_REDUCE_RES]]
   return %2 : tensor<5xi32>
 }
+
+// Ensure this case without endomorphism op not crash.
+// CHECK-LABEL: func.func @no_endomorphism_op
+func.func @no_endomorphism_op(%arg0: tensor<2xi64>) -> i64 {
+  %c0 = arith.constant 0 : index
+  %c1_i64 = arith.constant 1 : i64
+  // CHECK: tensor.extract
+  %extracted = tensor.extract %arg0[%c0] : tensor<2xi64>
+  // CHECK: arith.maxsi
+  %0 = arith.maxsi %extracted, %c1_i64 : i64
+  return %0 : i64
+}

>From 56a5ddb4be44e00033387ac3059a8d1fccb2ab1c Mon Sep 17 00:00:00 2001
From: Daniel Garvey <34486624+dan-garvey at users.noreply.github.com>
Date: Thu, 24 Jul 2025 12:06:41 -0500
Subject: [PATCH 04/22] [mlir] Fix missing import (#150330)

Building this file would fail (missing `llvm/ADT/ScopeExit.h` include)
when `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=1`.

Signed-off-by: dan <danimal197 at gmail.com>
---
 mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index b82d850413946..607b86cb86315 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ScopedPrinter.h"
 #include "llvm/Support/raw_ostream.h"

>From bb304ccf295dff5217bb3eec0ae3cf63eb7d6a83 Mon Sep 17 00:00:00 2001
From: ronigoldman22 <156088210+ronigoldman22 at users.noreply.github.com>
Date: Fri, 25 Jul 2025 03:24:04 +0300
Subject: [PATCH 05/22] Fix Bug in RemoveDeadValues Pass (#148437)

This patch fixes a bug in the RemoveDeadValues pass where unused
function arguments were not removed from the function signature in an
edge case where the function returns void.
A corresponding test was added to the MLIR LIT test suite to cover this
case.
---
 mlir/lib/Transforms/RemoveDeadValues.cpp     |  4 ++--
 mlir/test/Transforms/remove-dead-values.mlir | 23 ++++++++++++++++++++
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 608bdcb948176..ddd5f2ba1a7b7 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -345,8 +345,6 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
   // since it forwards only to non-live value(s) (%1#1).
   Operation *lastReturnOp = funcOp.back().getTerminator();
   size_t numReturns = lastReturnOp->getNumOperands();
-  if (numReturns == 0)
-    return;
   BitVector nonLiveRets(numReturns, true);
   for (SymbolTable::SymbolUse use : uses) {
     Operation *callOp = use.getUser();
@@ -368,6 +366,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
   cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
 
   // Do (5) and (6).
+  if (numReturns == 0)
+    return;
   for (SymbolTable::SymbolUse use : uses) {
     Operation *callOp = use.getUser();
     assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 3af95db3c0e24..9ded6a30d9c95 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -548,3 +548,26 @@ func.func @test_atomic_yield(%I: memref<10xf32>, %idx : index) {
   func.return
 }
 
+// -----
+
+// CHECK-LABEL: module @return_void_with_unused_argument
+module @return_void_with_unused_argument {
+  // CHECK-LABEL: func.func private @fn_return_void_with_unused_argument
+  // CHECK-SAME: (%[[ARG0_FN:.*]]: i32)
+  func.func private @fn_return_void_with_unused_argument(%arg0: i32, %arg1: memref<4xi32>) -> () {
+    %sum = arith.addi %arg0, %arg0 : i32
+    %c0 = arith.constant 0 : index
+    %buf = memref.alloc() : memref<1xi32>
+    memref.store %sum, %buf[%c0] : memref<1xi32>
+    return
+  }
+  // CHECK-LABEL: func.func @main
+  // CHECK-SAME: (%[[ARG0_MAIN:.*]]: i32)
+  // CHECK: call @fn_return_void_with_unused_argument(%[[ARG0_MAIN]]) : (i32) -> ()
+  func.func @main(%arg0: i32) -> memref<4xi32> {
+    %unused = memref.alloc() : memref<4xi32>
+    call @fn_return_void_with_unused_argument(%arg0, %unused) : (i32, memref<4xi32>) -> ()
+    return %unused : memref<4xi32>
+  }
+}
+

>From ab9a1d42766afe38db38d827ab5c8d7e62c1c4ca Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sat, 26 Jul 2025 11:07:06 +0200
Subject: [PATCH 06/22] [mlir][SCF] Do not access erased operation in
 `scf.while` lowering (#150741)

Do not access the erased `scf.while` operation in the lowering pattern.
That won't work anymore in a One-Shot Dialect Conversion and triggers a
use-after-free sanitizer error.

After the One-Shot Dialect Conversion refactoring, a
`ConversionPatternRewriter` will behave more like a normal
`PatternRewriter`.
---
 mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 0df91a243d07a..523dc463a0da1 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -582,6 +582,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
   // block. This should be reconsidered if we allow break/continue in SCF.
   rewriter.setInsertionPointToEnd(before);
   auto condOp = cast<ConditionOp>(before->getTerminator());
+  SmallVector<Value> args = llvm::to_vector(condOp.getArgs());
   rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
                                                 after, condOp.getArgs(),
                                                 continuation, ValueRange());
@@ -593,7 +594,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
 
   // Replace the op with values "yielded" from the "before" region, which are
   // visible by dominance.
-  rewriter.replaceOp(whileOp, condOp.getArgs());
+  rewriter.replaceOp(whileOp, args);
 
   return success();
 }

>From d6da92d3415a3518d78dea5d86cb3633c1692b6d Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Mon, 4 Aug 2025 13:22:59 +0100
Subject: [PATCH 07/22] [mlir][OpenMP][NFC] Fix gcc 14 warning (#151941)

GCC couldn't tell that the enum is checked exhaustively and so was
warning about there being no return on this path from the function.
---
 .../Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp   | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 3185f28fe6681..f8ea6ee07447b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3395,6 +3395,7 @@ static llvm::omp::Directive convertCancellationConstructType(
   case omp::ClauseCancellationConstructType::Taskgroup:
     return llvm::omp::Directive::OMPD_taskgroup;
   }
+  llvm_unreachable("Unhandled cancellation construct type");
 }
 
 static LogicalResult

>From 1c7f1c9a88a733b483475fe5f7c059e727d595a4 Mon Sep 17 00:00:00 2001
From: Boyana Norris <brnorris03 at gmail.com>
Date: Mon, 4 Aug 2025 23:16:14 -0700
Subject: [PATCH 08/22] [mlir] Clone attrs of unregistered dialect ops
 (#151847)

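`Operation::clone` does not clone the properties of unregistered ops.
This patch modifies the property initialization for unregistered ops so
that, when an initial value is provided, the attribute held as the op's
properties is copied into the newly initialized property storage.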

fixes #151640

---------

Signed-off-by: Boyana Norris <brnorris03 at gmail.com>
---
 mlir/lib/IR/MLIRContext.cpp  |  2 ++
 mlir/test/IR/test-clone.mlir | 35 +++++++++++++++++++++++++++++++++--
 2 files changed, 35 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 06ec1c85fb4d5..2d5381d43f863 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -884,6 +884,8 @@ int OperationName::UnregisteredOpModel::getOpPropertyByteSize() {
 void OperationName::UnregisteredOpModel::initProperties(
     OperationName opName, OpaqueProperties storage, OpaqueProperties init) {
   new (storage.as<Attribute *>()) Attribute();
+  if (init)
+    *storage.as<Attribute *>() = *init.as<Attribute *>();
 }
 void OperationName::UnregisteredOpModel::deleteProperties(
     OpaqueProperties prop) {
diff --git a/mlir/test/IR/test-clone.mlir b/mlir/test/IR/test-clone.mlir
index 0c07593aef32d..f723efc1a2c53 100644
--- a/mlir/test/IR/test-clone.mlir
+++ b/mlir/test/IR/test-clone.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(test-clone))" | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(func.func(test-clone))" --split-input-file | FileCheck %s
 
 module {
   func.func @fixpoint(%arg1 : i32) -> i32 {
@@ -18,7 +18,8 @@ module {
 // CHECK-NEXT: notifyOperationInserted: test.yield
 // CHECK-NEXT: notifyOperationInserted: func.return
 
-// CHECK:   func @fixpoint(%[[arg0:.+]]: i32) -> i32 {
+// CHECK-LABEL: func @fixpoint
+// CHECK-SAME:       (%[[arg0:.+]]: i32) -> i32 {
 // CHECK-NEXT:     %[[i0:.+]] = "test.use"(%[[arg0]]) ({
 // CHECK-NEXT:       %[[r2:.+]] = "test.use2"(%[[arg0]]) ({
 // CHECK-NEXT:         "test.yield2"(%[[arg0]]) : (i32) -> ()
@@ -33,3 +34,33 @@ module {
 // CHECK-NEXT:     }) : (i32) -> i32
 // CHECK-NEXT:     return %[[i1]] : i32
 // CHECK-NEXT:   }
+
+// -----
+
+func.func @clone_unregistered_with_attrs() {
+  "unregistered.foo"() <{bar = 1 : i64, flag = true, name = "test", value = 3.14 : f32}> : () -> ()
+  "unregistered.bar"() : () -> ()
+  "unregistered.empty_dict"() <{}> : () -> ()
+  "unregistered.complex"() <{
+    array = [1, 2, 3],
+    dict = {key1 = 42 : i32, key2 = "value"},
+    nested = {inner = {deep = 100 : i64}}
+  }> : () -> ()
+  return
+}
+
+// CHECK: notifyOperationInserted: unregistered.foo
+// CHECK-NEXT: notifyOperationInserted: unregistered.bar
+// CHECK-NEXT: notifyOperationInserted: unregistered.empty_dict
+// CHECK-NEXT: notifyOperationInserted: unregistered.complex
+// CHECK-NEXT: notifyOperationInserted: func.return
+
+// CHECK-LABEL:  func @clone_unregistered_with_attrs() {
+// CHECK-NEXT:     "unregistered.foo"() <{bar = 1 : i64, flag = true, name = "test", value = [[PI:.+]] : f32}> : () -> ()
+// CHECK-NEXT:     "unregistered.bar"() : () -> ()
+// CHECK-NEXT:     "unregistered.empty_dict"() <{}> : () -> ()
+// CHECK-NEXT:     "unregistered.complex"() <{array = [1, 2, 3], dict = {key1 = 42 : i32, key2 = "value"}, nested = {inner = {deep = 100 : i64}}}> : () -> ()
+// CHECK-NEXT:     "unregistered.foo"() <{bar = 1 : i64, flag = true, name = "test", value = [[PI]] : f32}> : () -> ()
+// CHECK-NEXT:     "unregistered.bar"() : () -> ()
+// CHECK-NEXT:     "unregistered.empty_dict"() <{}> : () -> ()
+// CHECK-NEXT:     "unregistered.complex"() <{array = [1, 2, 3], dict = {key1 = 42 : i32, key2 = "value"}, nested = {inner = {deep = 100 : i64}}}> : () -> ()

>From 753de6e7d37a628c35d5a20c52f596327bfab6ea Mon Sep 17 00:00:00 2001
From: Philip Lassen <plassen at groq.com>
Date: Mon, 4 Aug 2025 23:36:02 -0700
Subject: [PATCH 09/22] [NFC][mlir] Update DataFlowFramework.h to be compatible
 with clang c++23 (#152040)

This change makes `DataFlowFramework.h` compatible with `clang++` and
`--std=c++23`.

Previously, clang was checking the body of the templated
`DataFlowSolver::eraseState` before it was instantiated. This resulted
in errors about incomplete types (observed at least with `clang++-19`).

This is fixed by moving the definition of `DataFlowSolver::eraseState`
after the full definition of the `AnalysisState` class.

For full context:
-
https://discourse.llvm.org/t/what-is-the-status-of-c-23-support-in-mlir/87674/12
---
 .../include/mlir/Analysis/DataFlowFramework.h | 54 +++++++++++--------
 1 file changed, 31 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 49862927caff2..e364570c8b531 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -354,29 +354,7 @@ class DataFlowSolver {
 
   /// Erase any analysis state associated with the given lattice anchor.
   template <typename AnchorT>
-  void eraseState(AnchorT anchor) {
-    LatticeAnchor latticeAnchor(anchor);
-
-    // Update equivalentAnchorMap.
-    for (auto &&[TypeId, eqClass] : equivalentAnchorMap) {
-      if (!eqClass.contains(latticeAnchor)) {
-        continue;
-      }
-      llvm::EquivalenceClasses<LatticeAnchor>::member_iterator leaderIt =
-          eqClass.findLeader(latticeAnchor);
-
-      // Update analysis states with new leader if needed.
-      if (*leaderIt == latticeAnchor && ++leaderIt != eqClass.member_end()) {
-        analysisStates[*leaderIt][TypeId] =
-            std::move(analysisStates[latticeAnchor][TypeId]);
-      }
-
-      eqClass.erase(latticeAnchor);
-    }
-
-    // Update analysis states.
-    analysisStates.erase(latticeAnchor);
-  }
+  void eraseState(AnchorT anchor);
 
   /// Erase all analysis states.
   void eraseAllStates() {
@@ -560,6 +538,36 @@ class AnalysisState {
   friend class DataFlowSolver;
 };
 
+//===----------------------------------------------------------------------===//
+// DataFlowSolver definition
+//===----------------------------------------------------------------------===//
+// This method is defined outside `DataFlowSolver` and after `AnalysisState`
+// to prevent issues around `AnalysisState` being used before it is defined.
+template <typename AnchorT>
+void DataFlowSolver::eraseState(AnchorT anchor) {
+  LatticeAnchor latticeAnchor(anchor);
+
+  // Update equivalentAnchorMap.
+  for (auto &&[TypeId, eqClass] : equivalentAnchorMap) {
+    if (!eqClass.contains(latticeAnchor)) {
+      continue;
+    }
+    llvm::EquivalenceClasses<LatticeAnchor>::member_iterator leaderIt =
+        eqClass.findLeader(latticeAnchor);
+
+    // Update analysis states with new leader if needed.
+    if (*leaderIt == latticeAnchor && ++leaderIt != eqClass.member_end()) {
+      analysisStates[*leaderIt][TypeId] =
+          std::move(analysisStates[latticeAnchor][TypeId]);
+    }
+
+    eqClass.erase(latticeAnchor);
+  }
+
+  // Update analysis states.
+  analysisStates.erase(latticeAnchor);
+}
+
 //===----------------------------------------------------------------------===//
 // DataFlowAnalysis
 //===----------------------------------------------------------------------===//

>From d188f5a45ac87c4d7ef13324d882f3ae73d89e16 Mon Sep 17 00:00:00 2001
From: Boyana Norris <brnorris03 at gmail.com>
Date: Thu, 7 Aug 2025 06:54:30 -0700
Subject: [PATCH 10/22] [mlir][cmake] Fix MLIR shared library installation
 (#152195)

When `LLVM_INSTALL_TOOLCHAIN_ONLY=ON`, the MLIR shared library
(`libMLIR*`) is not installed even though it is built with the
`INSTALL_WITH_TOOLCHAIN` argument to the `add_mlir_library` cmake
function. This patch ensures that `libMLIR*` is installed when
`LLVM_INSTALL_TOOLCHAIN_ONLY=ON`.

Patch verified
[here](https://github.com/llvm/llvm-project/issues/151247#issuecomment-3156387055).

fixes #151247
---
 mlir/cmake/modules/AddMLIR.cmake | 44 +++++++++++++++++---------------
 1 file changed, 24 insertions(+), 20 deletions(-)

diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake
index ff4269ed7acd2..14eefb50ca714 100644
--- a/mlir/cmake/modules/AddMLIR.cmake
+++ b/mlir/cmake/modules/AddMLIR.cmake
@@ -388,6 +388,9 @@ function(add_mlir_library name)
 
   if(TARGET ${name})
     target_link_libraries(${name} INTERFACE ${LLVM_COMMON_LIBS})
+    if(ARG_INSTALL_WITH_TOOLCHAIN)
+      set_target_properties(${name} PROPERTIES MLIR_INSTALL_WITH_TOOLCHAIN TRUE)
+    endif()
     if(NOT ARG_DISABLE_INSTALL)
       add_mlir_library_install(${name})
     endif()
@@ -617,28 +620,29 @@ endfunction(add_mlir_aggregate)
 # This is usually done as part of add_mlir_library but is broken out for cases
 # where non-standard library builds can be installed.
 function(add_mlir_library_install name)
-  if (NOT LLVM_INSTALL_TOOLCHAIN_ONLY)
-  get_target_export_arg(${name} MLIR export_to_mlirtargets UMBRELLA mlir-libraries)
-  install(TARGETS ${name}
-    COMPONENT ${name}
-    ${export_to_mlirtargets}
-    LIBRARY DESTINATION lib${LLVM_LIBDIR_SUFFIX}
-    ARCHIVE DESTINATION lib${LLVM_LIBDIR_SUFFIX}
-    RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}"
-    # Note that CMake will create a directory like:
-    #   objects-${CMAKE_BUILD_TYPE}/obj.LibName
-    # and put object files there.
-    OBJECTS DESTINATION lib${LLVM_LIBDIR_SUFFIX}
-  )
+  get_target_property(_install_with_toolchain ${name} MLIR_INSTALL_WITH_TOOLCHAIN)
+  if (NOT LLVM_INSTALL_TOOLCHAIN_ONLY OR _install_with_toolchain)
+    get_target_export_arg(${name} MLIR export_to_mlirtargets UMBRELLA mlir-libraries)
+    install(TARGETS ${name}
+      COMPONENT ${name}
+      ${export_to_mlirtargets}
+      LIBRARY DESTINATION lib${LLVM_LIBDIR_SUFFIX}
+      ARCHIVE DESTINATION lib${LLVM_LIBDIR_SUFFIX}
+      RUNTIME DESTINATION "${CMAKE_INSTALL_BINDIR}"
+      # Note that CMake will create a directory like:
+      #   objects-${CMAKE_BUILD_TYPE}/obj.LibName
+      # and put object files there.
+      OBJECTS DESTINATION lib${LLVM_LIBDIR_SUFFIX}
+    )
 
-  if (NOT LLVM_ENABLE_IDE)
-    add_llvm_install_targets(install-${name}
-                            DEPENDS ${name}
-                            COMPONENT ${name})
-  endif()
-  set_property(GLOBAL APPEND PROPERTY MLIR_ALL_LIBS ${name})
+    if (NOT LLVM_ENABLE_IDE)
+      add_llvm_install_targets(install-${name}
+                              DEPENDS ${name}
+                              COMPONENT ${name})
+    endif()
+    set_property(GLOBAL APPEND PROPERTY MLIR_ALL_LIBS ${name})
+    set_property(GLOBAL APPEND PROPERTY MLIR_EXPORTS ${name})
   endif()
-  set_property(GLOBAL APPEND PROPERTY MLIR_EXPORTS ${name})
 endfunction()
 
 # Declare an mlir library which is part of the public C-API.

>From 705afe3e5c8403559f47f6d5f577aeab4c38af96 Mon Sep 17 00:00:00 2001
From: Sasa Vuckovic <svuckovic at tenstorrent.com>
Date: Fri, 8 Aug 2025 12:33:56 +0200
Subject: [PATCH 11/22] [MLIR] Make `PassPipelineOptions` virtually inheriting
 from PassOptions to allow diamond inheritance (#146370)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Problem

Given 3 pipelines, A, B, and a superset pipeline AB that runs both the A
& B pipelines, it is not easy to manage their options - one needs to
manually recreate all options from A and B into AB, and maintain them.
This is tedious.

## Proposed solution
Ideally, the AB options class inherits from both the A and B options,
making maintenance effortless. Today, though, this causes problems because
their base classes `PassPipelineOptions<A>` and `PassPipelineOptions<B>`
both inherit from `mlir::detail::PassOptions`, leading to the so-called
"diamond inheritance problem": ABOptions contains two separate
`mlir::detail::PassOptions` subobjects, so names declared there, such as
`parseFromString`, become ambiguous.

Visually, the inheritance looks like this:

```
                         mlir::detail::PassOptions
                            ↑                  ↑
                            |                  |
           PassPipelineOptions<A>      PassPipelineOptions<B>
                            ↑                  ↑
                            |                  |
                         AOptions           BOptions
                            ↑                  ↑
                            +---------+--------+
                                      |
                                  ABOptions
```

A proposed fix is to use the common solution to the diamond inheritance
problem - virtual inheritance.
---
 mlir/include/mlir/Pass/PassOptions.h         |  2 +-
 mlir/test/Pass/pipeline-options-parsing.mlir | 10 ++++
 mlir/test/lib/Pass/TestPassManager.cpp       | 59 ++++++++++++++++++++
 3 files changed, 70 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Pass/PassOptions.h b/mlir/include/mlir/Pass/PassOptions.h
index e1f16c6158ad5..0c71f78b52d3d 100644
--- a/mlir/include/mlir/Pass/PassOptions.h
+++ b/mlir/include/mlir/Pass/PassOptions.h
@@ -377,7 +377,7 @@ class PassOptions : protected llvm::cl::SubCommand {
 ///   ListOption<int> someListFlag{*this, "flag-name", llvm::cl::desc("...")};
 /// };
 template <typename T>
-class PassPipelineOptions : public detail::PassOptions {
+class PassPipelineOptions : public virtual detail::PassOptions {
 public:
   /// Factory that parses the provided options and returns a unique_ptr to the
   /// struct.
diff --git a/mlir/test/Pass/pipeline-options-parsing.mlir b/mlir/test/Pass/pipeline-options-parsing.mlir
index 9385d353faf95..03ac38ea16112 100644
--- a/mlir/test/Pass/pipeline-options-parsing.mlir
+++ b/mlir/test/Pass/pipeline-options-parsing.mlir
@@ -13,6 +13,7 @@
 // RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(builtin.module(func.func(test-options-pass{list=3}), func.func(test-options-pass{enum=one list=1,2,3,4 string=foo"bar"baz})))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_6 %s
 // RUN: mlir-opt %s -verify-each=false '-test-options-super-pass-pipeline=super-list={{enum=zero list=1 string=foo},{enum=one list=2 string="bar"},{enum=two list=3 string={baz}}}' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_7 %s
 // RUN: mlir-opt %s -verify-each=false -pass-pipeline='builtin.module(func.func(test-options-super-pass{list={{enum=zero list={1} string=foo },{enum=one list={2} string=bar },{enum=two list={3} string=baz }}}))' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_7 %s
+// RUN: mlir-opt %s -verify-each=false -test-options-super-set-ab-pipeline='foo=true bar=false' -dump-pass-pipeline 2>&1 | FileCheck --check-prefix=CHECK_11 %s
 
 
 // This test checks that lists-of-nested-options like 'option1={...},{....}' can be parsed
@@ -106,3 +107,12 @@
 // CHECK_10-NEXT:     test-options-pass{enum=zero  string= string-list={,}}
 // CHECK_10-NEXT:   )
 // CHECK_10-NEXT: )
+
+// CHECK_11:      builtin.module(
+// CHECK_11-NEXT:   func.func(
+// CHECK_11-NEXT:     test-options-pass-a
+// CHECK_11-NEXT:   )
+// CHECK_11-NEXT:   func.func(
+// CHECK_11-NEXT:     test-options-pass-b
+// CHECK_11-NEXT:   )
+// CHECK_11-NEXT: )
diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp
index 7afe2109f04db..2b5f75ef53f16 100644
--- a/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/mlir/test/lib/Pass/TestPassManager.cpp
@@ -133,6 +133,51 @@ struct TestOptionsSuperPass
       llvm::cl::desc("Example list of PassPipelineOptions option")};
 };
 
+struct TestOptionsPassA
+    : public PassWrapper<TestOptionsPassA, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOptionsPassA)
+
+  struct Options : public PassPipelineOptions<Options> {
+    Option<bool> foo{*this, "foo", llvm::cl::desc("Example boolean option")};
+  };
+
+  TestOptionsPassA() = default;
+  TestOptionsPassA(const TestOptionsPassA &) : PassWrapper() {}
+  TestOptionsPassA(const Options &options) { this->options.foo = options.foo; }
+
+  void runOnOperation() final {}
+  StringRef getArgument() const final { return "test-options-pass-a"; }
+  StringRef getDescription() const final {
+    return "Test superset options parsing capabilities - subset A";
+  }
+
+  Options options;
+};
+
+struct TestOptionsPassB
+    : public PassWrapper<TestOptionsPassB, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOptionsPassB)
+
+  struct Options : public PassPipelineOptions<Options> {
+    Option<bool> bar{*this, "bar", llvm::cl::desc("Example boolean option")};
+  };
+
+  TestOptionsPassB() = default;
+  TestOptionsPassB(const TestOptionsPassB &) : PassWrapper() {}
+  TestOptionsPassB(const Options &options) { this->options.bar = options.bar; }
+
+  void runOnOperation() final {}
+  StringRef getArgument() const final { return "test-options-pass-b"; }
+  StringRef getDescription() const final {
+    return "Test superset options parsing capabilities - subset B";
+  }
+
+  Options options;
+};
+
+struct TestPipelineOptionsSuperSetAB : TestOptionsPassA::Options,
+                                       TestOptionsPassB::Options {};
+
 /// A test pass that always aborts to enable testing the crash recovery
 /// mechanism of the pass manager.
 struct TestCrashRecoveryPass
@@ -270,6 +315,9 @@ void registerPassManagerTestPass() {
   PassRegistration<TestOptionsPass>();
   PassRegistration<TestOptionsSuperPass>();
 
+  PassRegistration<TestOptionsPassA>();
+  PassRegistration<TestOptionsPassB>();
+
   PassRegistration<TestModulePass>();
 
   PassRegistration<TestFunctionPass>();
@@ -306,5 +354,16 @@ void registerPassManagerTestPass() {
           [](OpPassManager &pm, const TestOptionsSuperPass::Options &options) {
             pm.addPass(std::make_unique<TestOptionsSuperPass>(options));
           });
+
+  PassPipelineRegistration<TestPipelineOptionsSuperSetAB>
+      registerPipelineOptionsSuperSetABPipeline(
+          "test-options-super-set-ab-pipeline",
+          "Parses options of PassPipelineOptions using pass pipeline "
+          "registration",
+          [](OpPassManager &pm, const TestPipelineOptionsSuperSetAB &options) {
+            // Pass superset AB options to subset options A and B
+            pm.addPass(std::make_unique<TestOptionsPassA>(options));
+            pm.addPass(std::make_unique<TestOptionsPassB>(options));
+          });
 }
 } // namespace mlir

>From 2065905dde7d69eb75f03f14e6e379da7bac8b33 Mon Sep 17 00:00:00 2001
From: yronglin <yronglin777 at gmail.com>
Date: Sun, 10 Aug 2025 09:11:34 +0800
Subject: [PATCH 12/22] [NFC][mlir] Fully qualify namespace to avoid an MSVC
 bug (#152860)

VS 17.6 has a name lookup issue (fixed in VS 17.7) that impacts
downstream MLIR-based projects. This change adds full namespace
qualifiers to work around the issue.
Reproducer: https://godbolt.org/z/Ea6e1Kc3E

Signed-off-by: yronglin <yronglin777 at gmail.com>
---
 mlir/include/mlir/IR/PatternMatch.h | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index afeb784b85a12..9565eaf4da762 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -311,14 +311,14 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
 /// opposed to a raw Operation.
 template <typename SourceOp>
 struct OpRewritePattern
-    : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+    : public mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp> {
 
   /// Patterns must specify the root operation name they match against, and can
   /// also specify the benefit of the pattern matching and a list of generated
   /// ops.
   OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1,
                    ArrayRef<StringRef> generatedNames = {})
-      : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
+      : mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp>(
             SourceOp::getOperationName(), benefit, context, generatedNames) {}
 };
 
@@ -327,10 +327,10 @@ struct OpRewritePattern
 /// of a raw Operation.
 template <typename SourceOp>
 struct OpInterfaceRewritePattern
-    : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+    : public mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp> {
 
   OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
-      : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
+      : mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp>(
             Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
             benefit, context) {}
 };

>From ca23337685fefa4fbcb32b212ab3cdc7817c889e Mon Sep 17 00:00:00 2001
From: Nick Smith <127986401+nsmithtt at users.noreply.github.com>
Date: Tue, 12 Aug 2025 13:27:05 -0500
Subject: [PATCH 13/22] [MLIR][Python] MLIR Enum Python bindings infinite
 recursion (#151584) (#151588)

Fixes an infinite recursion bug when using I32BitEnumAttrCaseGroup with
python bindings.

For more info, see issue:
- https://github.com/llvm/llvm-project/issues/151584
---
 mlir/test/mlir-tblgen/enums-python-bindings.td | 18 ++++++++++++------
 .../tools/mlir-tblgen/EnumPythonBindingGen.cpp |  2 +-
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/mlir/test/mlir-tblgen/enums-python-bindings.td b/mlir/test/mlir-tblgen/enums-python-bindings.td
index 1c5567f54a5f4..cd23b6a2effb9 100644
--- a/mlir/test/mlir-tblgen/enums-python-bindings.td
+++ b/mlir/test/mlir-tblgen/enums-python-bindings.td
@@ -62,12 +62,15 @@ def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>
 // CHECK: def _myenum64(x, context):
 // CHECK:     return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))
 
+def User : I32BitEnumAttrCaseBit<"User", 0, "user">;
+def Group : I32BitEnumAttrCaseBit<"Group", 1, "group">;
+def Other : I32BitEnumAttrCaseBit<"Other", 2, "other">;
+
 def TestBitEnum
-    : I32BitEnumAttr<"TestBitEnum", "", [
-        I32BitEnumAttrCaseBit<"User", 0, "user">,
-        I32BitEnumAttrCaseBit<"Group", 1, "group">,
-        I32BitEnumAttrCaseBit<"Other", 2, "other">,
-      ]> {
+    : I32BitEnumAttr<
+          "TestBitEnum", "",
+          [User, Group, Other,
+           I32BitEnumAttrCaseGroup<"Any", [User, Group, Other], "any">]> {
   let genSpecializedAttr = 0;
   let separator = " | ";
 }
@@ -79,9 +82,10 @@ def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;
 // CHECK:     User = 1
 // CHECK:     Group = 2
 // CHECK:     Other = 4
+// CHECK:     Any = 7
 
 // CHECK:     def __iter__(self):
-// CHECK:         return iter([case for case in type(self) if (self & case) is case])
+// CHECK:         return iter([case for case in type(self) if (self & case) is case and self is not case])
 // CHECK:     def __len__(self):
 // CHECK:         return bin(self).count("1")
 
@@ -94,6 +98,8 @@ def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;
 // CHECK:             return "group"
 // CHECK:         if self is TestBitEnum.Other:
 // CHECK:             return "other"
+// CHECK:         if self is TestBitEnum.Any:
+// CHECK:             return "any"
 // CHECK:         raise ValueError("Unknown TestBitEnum enum entry.")
 
 // CHECK: @register_attribute_builder("TestBitEnum")
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index 8e2d6114e48eb..acc9b61d7121c 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -64,7 +64,7 @@ static void emitEnumClass(EnumInfo enumInfo, raw_ostream &os) {
   if (enumInfo.isBitEnum()) {
     os << formatv("    def __iter__(self):\n"
                   "        return iter([case for case in type(self) if "
-                  "(self & case) is case])\n");
+                  "(self & case) is case and self is not case])\n");
     os << formatv("    def __len__(self):\n"
                   "        return bin(self).count(\"1\")\n");
     os << "\n";

>From 467e411049ce0977f77d958875b3796324119852 Mon Sep 17 00:00:00 2001
From: Baz <batzorig1691 at gmail.com>
Date: Wed, 13 Aug 2025 19:23:04 +0900
Subject: [PATCH 14/22] [mlir][ExecutionEngine] fix default free function in
 `OwningMemRef`. (#153133)

`basePtr` should be freed instead of `data`, because `basePtr` holds the
pointer returned by `malloc`. `data` is the aligned pointer derived from
it and may differ from `basePtr` when extra alignment padding is used.
---
 mlir/include/mlir/ExecutionEngine/MemRefUtils.h | 2 +-
 mlir/unittests/ExecutionEngine/Invoke.cpp       | 8 +++++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
index 6e72f7c23bdcf..d66d757cb7a8e 100644
--- a/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
+++ b/mlir/include/mlir/ExecutionEngine/MemRefUtils.h
@@ -151,7 +151,7 @@ class OwningMemRef {
       AllocFunType allocFun = &::malloc,
       std::function<void(StridedMemRefType<T, Rank>)> freeFun =
           [](StridedMemRefType<T, Rank> descriptor) {
-            ::free(descriptor.data);
+            ::free(descriptor.basePtr);
           })
       : freeFunc(freeFun) {
     if (shapeAlloc.empty())
diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp
index 887db227cfc4b..312b10f28143f 100644
--- a/mlir/unittests/ExecutionEngine/Invoke.cpp
+++ b/mlir/unittests/ExecutionEngine/Invoke.cpp
@@ -205,7 +205,13 @@ TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(BasicMemref)) {
   };
   int64_t shape[] = {k, m};
   int64_t shapeAlloc[] = {k + 1, m + 1};
-  OwningMemRef<float, 2> a(shape, shapeAlloc, init);
+  // Use a large alignment to stress the case where the memref data/basePtr are
+  // disjoint.
+  int alignment = 8192;
+  OwningMemRef<float, 2> a(shape, shapeAlloc, init, alignment);
+  ASSERT_EQ(
+      (void *)(((uintptr_t)a->basePtr + alignment - 1) & ~(alignment - 1)),
+      a->data);
   ASSERT_EQ(a->sizes[0], k);
   ASSERT_EQ(a->sizes[1], m);
   ASSERT_EQ(a->strides[0], m + 1);

>From c62ca9d66c6142bdd8666235631893ab200e7cc5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 13 Aug 2025 13:38:57 +0200
Subject: [PATCH 15/22] [mlir][DialectUtils] Fix div by zero crash (#153380)

---
 mlir/lib/Dialect/Utils/StaticValueUtils.cpp |  2 +-
 mlir/test/Dialect/SCF/canonicalize.mlir     | 13 +++++++++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 1cded38c4419e..36059e553d30e 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -272,7 +272,7 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
   if (!ubConstant)
     return std::nullopt;
   std::optional<int64_t> stepConstant = getConstantIntValue(step);
-  if (!stepConstant)
+  if (!stepConstant || *stepConstant == 0)
     return std::nullopt;
 
   return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant);
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8ba8013d008a0..b15fabdd29c61 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1925,3 +1925,16 @@ func.func @index_switch_fold_no_res() {
 
 // CHECK-LABEL: func.func @index_switch_fold_no_res()
 //  CHECK-NEXT: "test.op"() : () -> ()
+
+// -----
+
+// CHECK-LABEL: func @scf_for_all_step_size_0()
+//       CHECK:   scf.forall (%{{.*}}) = (0) to (1) step (0)
+func.func @scf_for_all_step_size_0()  {
+  %x = arith.constant 0 : index
+  scf.forall (%i, %j) = (0, 4) to (1, 5) step (%x, 8) {
+    vector.print %x : index
+    scf.forall.in_parallel {}
+  }
+  return
+}

>From e316672a113644b5b9cecc1d49ba6ee15843c7e3 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Mon, 18 Aug 2025 10:59:43 +0200
Subject: [PATCH 16/22] [MLIR] Erase unreachable blocks before applying
 patterns in the greedy rewriter (#153957)

Operations like:

    %add = arith.addi %add, %add : i64

are legal in unreachable code. Unfortunately many patterns would be
unsafe to apply on such IR and can lead to crashes or infinite loops. To
avoid this we can remove unreachable blocks before attempting to apply
patterns.
We may also need to do this whenever a pattern changes the CFG; that is
left for future work.

Fixes #153732
---
 .../Utils/GreedyPatternRewriteDriver.cpp         | 13 ++++++++++++-
 mlir/test/Dialect/Arith/canonicalize.mlir        | 15 +++++++++++++++
 mlir/test/Transforms/test-canonicalize.mlir      | 16 +++++++---------
 3 files changed, 34 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 607b86cb86315..0a2a0cc1d5c73 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -871,7 +871,18 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
 
     ctx->executeAction<GreedyPatternRewriteIteration>(
         [&] {
-          continueRewrites = processWorklist();
+          continueRewrites = false;
+
+          // Erase unreachable blocks
+          // Operations like:
+          //   %add = arith.addi %add, %add : i64
+          // are legal in unreachable code. Unfortunately many patterns would be
+          // unsafe to apply on such IR and can lead to crashes or infinite
+          // loops.
+          continueRewrites |=
+              succeeded(eraseUnreachableBlocks(rewriter, region));
+
+          continueRewrites |= processWorklist();
 
           // After applying patterns, make sure that the CFG of each of the
           // regions is kept up to date.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 3d5a46d13e59d..cf570beb14b9e 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -3363,3 +3363,18 @@ func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>,
     }
   }
 #-}
+
+// CHECK-LABEL: func @unreachable()
+// CHECK-NEXT: return
+// CHECK-NOT: arith
+func.func @unreachable() {
+  return
+^unreachable:
+  %c1_i64 = arith.constant 1 : i64
+  // This self referencing operation is legal in an unreachable block.
+  // Many patterns are unsafe with respect to this kind of situation,
+  // check that we don't infinite loop here.
+  %add = arith.addi %add, %c1_i64 : i64
+  cf.br ^unreachable
+}
+
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index 0fc822b0a23ae..8cad6b98441df 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s  --check-prefixes=CHECK,RS
 // RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=disabled}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
 
 // CHECK-LABEL: func @remove_op_with_inner_ops_pattern
@@ -80,12 +80,10 @@ func.func @test_dialect_canonicalizer() -> (i32) {
 
 // Check that the option to control region simplification actually works
 // CHECK-LABEL: test_region_simplify
-func.func @test_region_simplify() {
-  // CHECK-NEXT:   return
-  // NO-RS-NEXT: ^bb1
-  // NO-RS-NEXT:   return
-  // CHECK-NEXT: }
-  return
-^bb1:
-  return
+func.func @test_region_simplify(%input1 : i32, %cond : i1) -> i32 {
+  // RS-NEXT: "test.br"(%arg0)[^bb1] : (i32) -> ()
+  // NO-RS-NEXT: "test.br"(%arg0, %arg0)[^bb1] : (i32, i32) -> ()
+   "test.br"(%input1, %input1)[^bb1] : (i32, i32) -> ()
+^bb1(%used_arg : i32, %unused_arg : i32):
+  return %used_arg : i32
 }

>From a2339cd3fd00479de6d0341703cbe7d777a8423e Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Mon, 18 Aug 2025 11:07:19 +0200
Subject: [PATCH 17/22] [MLIR] Refactor the walkAndApplyPatterns driver to
 remove the recursion (#154037)

This is in preparation of a follow-up change to stop traversing
unreachable blocks.

This is not NFC because of a subtlety of the `make_early_inc_range`
iteration used by the previous recursive walk. On a test case like:

```
  scf.if %cond {
    "test.move_after_parent_op"() ({
      "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
    }) : () -> ()
  }
```

We recursively traverse the nested regions and process an op once its
nested regions are done (post-order). We need to pre-increment the
iterator before processing an operation in case it gets deleted;
however, we can do this either before or after processing the nested
regions. This implementation does the latter, so an op that a pattern
moves after its parent op (as above) is now visited a second time.
---
 .../Utils/WalkPatternRewriteDriver.cpp        | 99 +++++++++++++++++--
 .../IR/test-walk-pattern-rewrite-driver.mlir  |  4 +-
 2 files changed, 91 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index ee5c642c943c4..2111e29120567 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -13,12 +13,14 @@
 #include "mlir/Transforms/WalkPatternRewriteDriver.h"
 
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Rewrite/PatternApplicator.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
 #include "llvm/Support/ErrorHandling.h"
 
 #define DEBUG_TYPE "walk-rewriter"
@@ -88,20 +90,97 @@ void walkAndApplyPatterns(Operation *op,
   PatternApplicator applicator(patterns);
   applicator.applyDefaultCostModel();
 
+  // Iterator on all reachable operations in the region.
+  // Also keep track if we visited the nested regions of the current op
+  // already to drive the post-order traversal.
+  struct RegionReachableOpIterator {
+    RegionReachableOpIterator(Region *region) : region(region) {
+      regionIt = region->begin();
+      if (regionIt != region->end())
+        blockIt = regionIt->begin();
+    }
+    // Advance the iterator to the next reachable operation.
+    void advance() {
+      assert(regionIt != region->end());
+      hasVisitedRegions = false;
+      if (blockIt == regionIt->end()) {
+        ++regionIt;
+        if (regionIt != region->end())
+          blockIt = regionIt->begin();
+        return;
+      }
+      ++blockIt;
+      if (blockIt != regionIt->end()) {
+        LDBG() << "Incrementing block iterator, next op: "
+               << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
+      }
+    }
+    // The region we're iterating over.
+    Region *region;
+    // The Block currently being iterated over.
+    Region::iterator regionIt;
+    // The Operation currently being iterated over.
+    Block::iterator blockIt;
+    // Whether we've visited the nested regions of the current op already.
+    bool hasVisitedRegions = false;
+  };
+
+  // Worklist of regions to visit to drive the post-order traversal.
+  SmallVector<RegionReachableOpIterator> worklist;
+
+  LDBG() << "Starting walk-based pattern rewrite driver";
   ctx->executeAction<WalkAndApplyPatternsAction>(
       [&] {
+        // Perform a post-order traversal of the regions, visiting each
+        // reachable operation.
         for (Region &region : op->getRegions()) {
-          region.walk([&](Operation *visitedOp) {
-            LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
-                llvm::dbgs(), OpPrintingFlags().skipRegions());
-                       llvm::dbgs() << "\n";);
+          assert(worklist.empty());
+          if (region.empty())
+            continue;
+
+          // Prime the worklist with the entry block of this region.
+          worklist.push_back({&region});
+          while (!worklist.empty()) {
+            RegionReachableOpIterator &it = worklist.back();
+            if (it.regionIt == it.region->end()) {
+              // We're done with this region.
+              worklist.pop_back();
+              continue;
+            }
+            if (it.blockIt == it.regionIt->end()) {
+              // We're done with this block.
+              it.advance();
+              continue;
+            }
+            Operation *op = &*it.blockIt;
+            // If we haven't visited the nested regions of this op yet,
+            // enqueue them.
+            if (!it.hasVisitedRegions) {
+              it.hasVisitedRegions = true;
+              for (Region &nestedRegion : llvm::reverse(op->getRegions())) {
+                if (nestedRegion.empty())
+                  continue;
+                worklist.push_back({&nestedRegion});
+              }
+            }
+            // If we're not at the back of the worklist, we've enqueued some
+            // nested region for processing. We'll come back to this op later
+            // (post-order)
+            if (&it != &worklist.back())
+              continue;
+
+            // Preemptively increment the iterator, in case the current op
+            // would be erased.
+            it.advance();
+
+            LDBG() << "Visiting op: "
+                   << OpWithFlags(op, OpPrintingFlags().skipRegions());
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-            erasedListener.visitedOp = visitedOp;
+            erasedListener.visitedOp = op;
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-            if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
-              LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
-            }
-          });
+            if (succeeded(applicator.matchAndRewrite(op, rewriter)))
+              LDBG() << "\tOp matched and rewritten";
+          }
         }
       },
       {op});
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
index 02f7e60671c9b..c75c478ec3734 100644
--- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -40,12 +40,12 @@ func.func @move_before(%cond : i1) {
 }
 
 // Check that the driver handles rewriter.moveAfter. In this case, we expect
-// the moved op to be visited only once since walk uses `make_early_inc_range`.
+// the moved op to be visited twice.
 // CHECK-LABEL: func.func @move_after(
 // CHECK: scf.if
 // CHECK: }
 // CHECK: "test.move_after_parent_op"
-// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
+// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
 // CHECK: return
 func.func @move_after(%cond : i1) {
   scf.if %cond {

>From b6cd6f2ed67c85137846e63862158c86a8c956aa Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Mon, 18 Aug 2025 12:48:55 +0200
Subject: [PATCH 18/22] [MLIR] Fix SCF verifier crash (#153974)

An operand of the nested yield op can be null and hasn't been verified
yet when processing the enclosing operation. Using `getOperandTypes()`
on the yield will dereference this null Value and crash in the verifier.
---
 mlir/lib/Dialect/SCF/IR/SCF.cpp | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 00c31a1500e17..dbea050b554ea 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4273,14 +4273,15 @@ LogicalResult scf::IndexSwitchOp::verify() {
              << "see yield operation here";
     }
     for (auto [idx, result, operand] :
-         llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
-                   yield.getOperandTypes())) {
-      if (result == operand)
+         llvm::enumerate(getResultTypes(), yield.getOperands())) {
+      if (!operand)
+        return yield.emitOpError() << "operand " << idx << " is null\n";
+      if (result == operand.getType())
         continue;
       return (emitOpError("expected result #")
               << idx << " of each region to be " << result)
                  .attachNote(yield.getLoc())
-             << name << " returns " << operand << " here";
+             << name << " returns " << operand.getType() << " here";
     }
     return success();
   };

>From 8fd70f7885a77410170523f32da9a0a0a97329dd Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Mon, 18 Aug 2025 22:46:59 +0200
Subject: [PATCH 19/22] [MLIR] Stop visiting unreachable blocks in the
 walkAndApplyPatterns driver (#154038)

This is similar to the fix to the greedy driver in #153957, except that
instead of removing unreachable code, we just ignore it.

Operations like:

```
%add = arith.addi %add, %add : i64
```

are legal in unreachable code.
Unfortunately, many patterns are unsafe to apply to such IR and can
lead to crashes or infinite loops.
---
 .../Transforms/WalkPatternRewriteDriver.h     |  2 +
 .../Utils/WalkPatternRewriteDriver.cpp        | 49 ++++++++++++++++---
 .../IR/test-walk-pattern-rewrite-driver.mlir  | 20 ++++++++
 3 files changed, 64 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
index 6d62ae3dd43dc..7d5c1d5cebb26 100644
--- a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
@@ -27,6 +27,8 @@ namespace mlir {
 /// This is intended as the simplest and most lightweight pattern rewriter in
 /// cases when a simple walk gets the job done.
 ///
+/// The driver will skip unreachable blocks.
+///
 /// Note: Does not apply patterns to the given operation itself.
 void walkAndApplyPatterns(Operation *op,
                           const FrozenRewritePatternSet &patterns,
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index 2111e29120567..baa76b9aab4e5 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -20,13 +20,33 @@
 #include "mlir/IR/Visitors.h"
 #include "mlir/Rewrite/PatternApplicator.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 
 #define DEBUG_TYPE "walk-rewriter"
 
 namespace mlir {
 
+// Insert into `reachableBlocks` every block of `region` reachable from its
+// entry block. `region` and each visited block must be non-empty (unchecked).
+static void findReachableBlocks(Region &region,
+                                DenseSet<Block *> &reachableBlocks) {
+  Block *entryBlock = &region.front();
+  reachableBlocks.insert(entryBlock);
+  // Worklist traversal of the CFG along terminator successors.
+  SmallVector<Block *> worklist({entryBlock});
+  while (!worklist.empty()) {
+    Block *block = worklist.pop_back_val();
+    Operation *terminator = &block->back();
+    for (Block *successor : terminator->getSuccessors()) {
+      if (reachableBlocks.contains(successor))
+        continue;
+      worklist.push_back(successor);
+      reachableBlocks.insert(successor);
+    }
+  }
+}
+
 namespace {
 struct WalkAndApplyPatternsAction final
     : tracing::ActionImpl<WalkAndApplyPatternsAction> {
@@ -98,6 +118,8 @@ void walkAndApplyPatterns(Operation *op,
       regionIt = region->begin();
       if (regionIt != region->end())
         blockIt = regionIt->begin();
+      if (!llvm::hasSingleElement(*region))
+        findReachableBlocks(*region, reachableBlocks);
     }
     // Advance the iterator to the next reachable operation.
     void advance() {
@@ -105,14 +127,21 @@ void walkAndApplyPatterns(Operation *op,
       hasVisitedRegions = false;
       if (blockIt == regionIt->end()) {
         ++regionIt;
+        while (regionIt != region->end() &&
+               !reachableBlocks.contains(&*regionIt))
+          ++regionIt;
         if (regionIt != region->end())
           blockIt = regionIt->begin();
         return;
       }
       ++blockIt;
       if (blockIt != regionIt->end()) {
-        LDBG() << "Incrementing block iterator, next op: "
-               << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
+        LLVM_DEBUG({
+          llvm::dbgs() << "Incrementing block iterator, next op: "
+                       << OpWithFlags(&*blockIt,
+                                      OpPrintingFlags().skipRegions())
+                       << "\n";
+        });
       }
     }
     // The region we're iterating over.
@@ -121,6 +150,8 @@ void walkAndApplyPatterns(Operation *op,
     Region::iterator regionIt;
     // The Operation currently being iterated over.
     Block::iterator blockIt;
+    // The set of blocks that are reachable in the current region.
+    DenseSet<Block *> reachableBlocks;
     // Whether we've visited the nested regions of the current op already.
     bool hasVisitedRegions = false;
   };
@@ -128,7 +159,8 @@ void walkAndApplyPatterns(Operation *op,
   // Worklist of regions to visit to drive the post-order traversal.
   SmallVector<RegionReachableOpIterator> worklist;
 
-  LDBG() << "Starting walk-based pattern rewrite driver";
+  LLVM_DEBUG(
+      { llvm::dbgs() << "Starting walk-based pattern rewrite driver\n"; });
   ctx->executeAction<WalkAndApplyPatternsAction>(
       [&] {
         // Perform a post-order traversal of the regions, visiting each
@@ -173,13 +205,16 @@ void walkAndApplyPatterns(Operation *op,
             // would be erased.
             it.advance();
 
-            LDBG() << "Visiting op: "
-                   << OpWithFlags(op, OpPrintingFlags().skipRegions());
+            LLVM_DEBUG({
+              llvm::dbgs() << "Visiting op: "
+                           << OpWithFlags(op, OpPrintingFlags().skipRegions())
+                           << "\n";
+            });
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
             erasedListener.visitedOp = op;
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
             if (succeeded(applicator.matchAndRewrite(op, rewriter)))
-              LDBG() << "\tOp matched and rewritten";
+              LLVM_DEBUG({ llvm::dbgs() << "\tOp matched and rewritten\n"; });
           }
         }
       },
diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
index c75c478ec3734..c3063416b0360 100644
--- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
+++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
@@ -119,3 +119,23 @@ func.func @erase_nested_block() -> i32 {
   }): () -> (i32)
   return %a : i32
 }
+
+
+// CHECK-LABEL: func.func @unreachable_replace_with_new_op
+// CHECK: "test.new_op"
+// CHECK: "test.replace_with_new_op"
+// CHECK-SAME: unreachable
+// CHECK: "test.new_op"
+func.func @unreachable_replace_with_new_op() {
+  "test.br"()[^bb1] : () -> ()
+^bb1:
+  %a = "test.replace_with_new_op"() : () -> (i32)
+  "test.br"()[^end] : () -> () // Test jumping over the unreachable block is visited as well.
+^unreachable:
+  %b = "test.replace_with_new_op"() {test.unreachable} : () -> (i32)
+  return
+^end:
+  %c = "test.replace_with_new_op"() : () -> (i32)
+  return
+}
+

>From b8572edd397e677363f494526fb1212b26c6ff2a Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Mon, 18 Aug 2025 22:50:36 +0200
Subject: [PATCH 20/22] [MLIR] Fix Liveness analysis handling of unreachable
 code (#153973)

This patch forces all values to be initialized by the LivenessAnalysis,
even in dead blocks. The dataflow framework skips visiting values when it
already knows that a block is dynamically unreachable, so this requires
specific handling.
Downstream code could consider the absence of liveness to be the same as
"dead". However, as the code is mutated, new values can be introduced, and
a transformation like "RemoveDeadValues" must conservatively treat the
absence of liveness information as meaning that we weren't sure whether a
value was dead (it could be a newly introduced value).

Fixes #153906
---
 .../Analysis/DataFlow/LivenessAnalysis.cpp    |  40 ++++++-
 mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp |  96 ++++++++++++++-
 mlir/lib/Transforms/RemoveDeadValues.cpp      | 109 ++++++++++++++++--
 .../DataFlow/test-liveness-analysis.mlir      |  20 ++++
 mlir/test/Transforms/remove-dead-values.mlir  |  21 ++++
 .../DataFlow/TestLivenessAnalysis.cpp         |   1 -
 6 files changed, 272 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index 6a12fe3acc2c2..e1d7498f7be35 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -295,9 +295,45 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) {
 
   loadBaselineAnalyses(solver);
   solver.load<LivenessAnalysis>(symbolTable);
-  LDBG("Initializing and running solver");
+  LLVM_DEBUG({ llvm::dbgs() << "Initializing and running solver\n"; });
   (void)solver.initializeAndRun(op);
-  LDBG("Dumping liveness state for op");
+  LLVM_DEBUG({
+    llvm::dbgs() << "RunLivenessAnalysis initialized for op: " << op->getName()
+                 << " check on unreachable code now:"
+                 << "\n";
+  });
+  // The framework doesn't visit operations in dead blocks, so we need to
+  // explicitly mark them as dead.
+  op->walk([&](Operation *op) {
+    if (op->getNumResults() == 0)
+      return;
+    for (auto result : llvm::enumerate(op->getResults())) {
+      if (getLiveness(result.value()))
+        continue;
+      LLVM_DEBUG({
+        llvm::dbgs() << "Result: " << result.index() << " of "
+                     << OpWithFlags(op, OpPrintingFlags().skipRegions())
+                     << " has no liveness info (unreachable), mark dead"
+                     << "\n";
+      });
+      solver.getOrCreateState<Liveness>(result.value());
+    }
+    for (auto &region : op->getRegions()) {
+      for (auto &block : region) {
+        for (auto blockArg : llvm::enumerate(block.getArguments())) {
+          if (getLiveness(blockArg.value()))
+            continue;
+          LLVM_DEBUG({
+            llvm::dbgs() << "Block argument: " << blockArg.index() << " of "
+                         << OpWithFlags(op, OpPrintingFlags().skipRegions())
+                         << " has no liveness info, mark dead"
+                         << "\n";
+          });
+          solver.getOrCreateState<Liveness>(blockArg.value());
+        }
+      }
+    }
+  });
 }
 
 const Liveness *RunLivenessAnalysis::getLiveness(Value val) {
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index e625f626d12fd..5e342fd87773e 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -19,12 +19,15 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
 #include <cassert>
 #include <optional>
 
 using namespace mlir;
 using namespace mlir::dataflow;
 
+#define DEBUG_TYPE "dataflow"
+
 //===----------------------------------------------------------------------===//
 // AbstractSparseLattice
 //===----------------------------------------------------------------------===//
@@ -64,22 +67,56 @@ AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) {
 
 LogicalResult
 AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
+  LLVM_DEBUG({
+    llvm::dbgs() << "Initializing recursively for operation: " << op->getName()
+                 << "\n";
+  });
+
   // Initialize the analysis by visiting every owner of an SSA value (all
   // operations and blocks).
-  if (failed(visitOperation(op)))
+  if (failed(visitOperation(op))) {
+    LLVM_DEBUG({
+      llvm::dbgs() << "Failed to visit operation: " << op->getName() << "\n";
+    });
     return failure();
+  }
 
   for (Region &region : op->getRegions()) {
+    LLVM_DEBUG({
+      llvm::dbgs() << "Processing region with " << region.getBlocks().size()
+                   << " blocks"
+                   << "\n";
+    });
     for (Block &block : region) {
+      LLVM_DEBUG({
+        llvm::dbgs() << "Processing block with " << block.getNumArguments()
+                     << " arguments"
+                     << "\n";
+      });
       getOrCreate<Executable>(getProgramPointBefore(&block))
           ->blockContentSubscribe(this);
       visitBlock(&block);
-      for (Operation &op : block)
-        if (failed(initializeRecursively(&op)))
+      for (Operation &op : block) {
+        LLVM_DEBUG({
+          llvm::dbgs() << "Recursively initializing nested operation: "
+                       << op.getName() << "\n";
+        });
+        if (failed(initializeRecursively(&op))) {
+          LLVM_DEBUG({
+            llvm::dbgs() << "Failed to initialize nested operation: "
+                         << op.getName() << "\n";
+          });
           return failure();
+        }
+      }
     }
   }
 
+  LLVM_DEBUG({
+    llvm::dbgs()
+        << "Successfully completed recursive initialization for operation: "
+        << op->getName() << "\n";
+  });
   return success();
 }
 
@@ -409,11 +446,29 @@ static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
 
 LogicalResult
 AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
+  LLVM_DEBUG({
+    llvm::dbgs() << "Visiting operation: " << op->getName() << " with "
+                 << op->getNumOperands() << " operands and "
+                 << op->getNumResults() << " results"
+                 << "\n";
+  });
+
   // If we're in a dead block, bail out.
   if (op->getBlock() != nullptr &&
-      !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
+      !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
+           ->isLive()) {
+    LLVM_DEBUG({
+      llvm::dbgs() << "Operation is in dead block, bailing out"
+                   << "\n";
+    });
     return success();
+  }
 
+  LLVM_DEBUG({
+    llvm::dbgs() << "Creating lattice elements for " << op->getNumOperands()
+                 << " operands and " << op->getNumResults() << " results"
+                 << "\n";
+  });
   SmallVector<AbstractSparseLattice *> operandLattices =
       getLatticeElements(op->getOperands());
   SmallVector<const AbstractSparseLattice *> resultLattices =
@@ -422,11 +477,21 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
   // Block arguments of region branch operations flow back into the operands
   // of the parent op
   if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+    LLVM_DEBUG({
+      llvm::dbgs() << "Processing RegionBranchOpInterface operation"
+                   << "\n";
+    });
     visitRegionSuccessors(branch, operandLattices);
     return success();
   }
 
   if (auto branch = dyn_cast<BranchOpInterface>(op)) {
+    LLVM_DEBUG({
+      llvm::dbgs() << "Processing BranchOpInterface operation with "
+                   << op->getNumSuccessors() << " successors"
+                   << "\n";
+    });
+
     // Block arguments of successor blocks flow back into our operands.
 
     // We remember all operands not forwarded to any block in a BitVector.
@@ -463,6 +528,10 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
   // For function calls, connect the arguments of the entry blocks to the
   // operands of the call op that are forwarded to these arguments.
   if (auto call = dyn_cast<CallOpInterface>(op)) {
+    LLVM_DEBUG({
+      llvm::dbgs() << "Processing CallOpInterface operation"
+                   << "\n";
+    });
     Operation *callableOp = call.resolveCallableInTable(&symbolTable);
     if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
       // Not all operands of a call op forward to arguments. Such operands are
@@ -513,6 +582,10 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
   // of this op itself and the operands of the terminators of the regions of
   // this op.
   if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
+    LLVM_DEBUG({
+      llvm::dbgs() << "Processing RegionBranchTerminatorOpInterface operation"
+                   << "\n";
+    });
     if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
       visitRegionSuccessorsFromTerminator(terminator, branch);
       return success();
@@ -520,12 +593,25 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
   }
 
   if (op->hasTrait<OpTrait::ReturnLike>()) {
+    LLVM_DEBUG({
+      llvm::dbgs() << "Processing ReturnLike operation"
+                   << "\n";
+    });
     // Going backwards, the operands of the return are derived from the
     // results of all CallOps calling this CallableOp.
-    if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp()))
+    if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
+      LLVM_DEBUG({
+        llvm::dbgs() << "Callable parent found, visiting callable operation"
+                     << "\n";
+      });
       return visitCallableOperation(op, callable, operandLattices);
+    }
   }
 
+  LLVM_DEBUG({
+    llvm::dbgs() << "Using default visitOperationImpl for operation: "
+                 << op->getName() << "\n";
+  });
   return visitOperationImpl(op, operandLattices, resultLattices);
 }
 
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index ddd5f2ba1a7b7..1d7e2135e23e1 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -258,16 +258,22 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
                             RDVFinalCleanupList &cl) {
-  LDBG("Processing simple op: " << *op);
   if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
-    LDBG("Simple op is not memory effect free or has live results, skipping: "
-         << *op);
+    LLVM_DEBUG({
+      llvm::dbgs()
+          << "Simple op is not memory effect free or has live results, "
+             "preserving it: "
+          << OpWithFlags(op, OpPrintingFlags().skipRegions()) << "\n";
+    });
     return;
   }
 
-  LDBG("Simple op has all dead results and is memory effect free, scheduling "
-       "for removal: "
-       << *op);
+  LLVM_DEBUG({
+    llvm::dbgs() << "Simple op has all dead results and is memory effect free, "
+                    "scheduling "
+                    "for removal: "
+                 << OpWithFlags(op, OpPrintingFlags().skipRegions()) << "\n";
+  });
   cl.operations.push_back(op);
   collectNonLiveValues(nonLiveSet, op->getResults(),
                        BitVector(op->getNumResults(), true));
@@ -727,19 +733,53 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
 /// Removes dead values collected in RDVFinalCleanupList.
 /// To be run once when all dead values have been collected.
 static void cleanUpDeadVals(RDVFinalCleanupList &list) {
+  LLVM_DEBUG({ llvm::dbgs() << "Starting cleanup of dead values...\n"; });
+
   // 1. Operations
+  LLVM_DEBUG({
+    llvm::dbgs() << "Cleaning up " << list.operations.size() << " operations"
+                 << "\n";
+  });
   for (auto &op : list.operations) {
+    LLVM_DEBUG({
+      llvm::dbgs() << "Erasing operation: "
+                   << OpWithFlags(op, OpPrintingFlags().skipRegions()) << "\n";
+    });
     op->dropAllUses();
     op->erase();
   }
 
   // 2. Values
+  LLVM_DEBUG({
+    llvm::dbgs() << "Cleaning up " << list.values.size() << " values"
+                 << "\n";
+  });
   for (auto &v : list.values) {
+    LLVM_DEBUG(
+        { llvm::dbgs() << "Dropping all uses of value: " << v << "\n"; });
     v.dropAllUses();
   }
 
   // 3. Functions
+  LLVM_DEBUG({
+    llvm::dbgs() << "Cleaning up " << list.functions.size() << " functions"
+                 << "\n";
+  });
   for (auto &f : list.functions) {
+    LLVM_DEBUG({
+      llvm::dbgs() << "Cleaning up function: "
+                   << f.funcOp.getOperation()->getName() << "\n";
+    });
+    LLVM_DEBUG({
+      llvm::dbgs() << "  Erasing " << f.nonLiveArgs.count()
+                   << " non-live arguments"
+                   << "\n";
+    });
+    LLVM_DEBUG({
+      llvm::dbgs() << "  Erasing " << f.nonLiveRets.count()
+                   << " non-live return values"
+                   << "\n";
+    });
     // Some functions may not allow erasing arguments or results. These calls
     // return failure in such cases without modifying the function, so it's okay
     // to proceed.
@@ -748,44 +788,99 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
   }
 
   // 4. Operands
+  LLVM_DEBUG({
+    llvm::dbgs() << "Cleaning up " << list.operands.size() << " operand lists"
+                 << "\n";
+  });
   for (OperationToCleanup &o : list.operands) {
-    if (o.op->getNumOperands() > 0)
+    if (o.op->getNumOperands() > 0) {
+      LLVM_DEBUG({
+        llvm::dbgs() << "Erasing " << o.nonLive.count()
+                     << " non-live operands from operation: "
+                     << OpWithFlags(o.op, OpPrintingFlags().skipRegions())
+                     << "\n";
+      });
       o.op->eraseOperands(o.nonLive);
+    }
   }
 
   // 5. Results
+  LLVM_DEBUG({
+    llvm::dbgs() << "Cleaning up " << list.results.size() << " result lists"
+                 << "\n";
+  });
   for (auto &r : list.results) {
+    LLVM_DEBUG({
+      llvm::dbgs() << "Erasing " << r.nonLive.count()
+                   << " non-live results from operation: "
+                   << OpWithFlags(r.op, OpPrintingFlags().skipRegions())
+                   << "\n";
+    });
     dropUsesAndEraseResults(r.op, r.nonLive);
   }
 
   // 6. Blocks
+  LLVM_DEBUG({
+    llvm::dbgs() << "Cleaning up " << list.blocks.size()
+                 << " block argument lists"
+                 << "\n";
+  });
   for (auto &b : list.blocks) {
     // blocks that are accessed via multiple codepaths processed once
     if (b.b->getNumArguments() != b.nonLiveArgs.size())
       continue;
+    LLVM_DEBUG({
+      llvm::dbgs() << "Erasing " << b.nonLiveArgs.count()
+                   << " non-live arguments from block: " << b.b << "\n";
+    });
     // it iterates backwards because erase invalidates all successor indexes
     for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
       if (!b.nonLiveArgs[i])
         continue;
+      LLVM_DEBUG({
+        llvm::dbgs() << "  Erasing block argument " << i << ": "
+                     << b.b->getArgument(i) << "\n";
+      });
       b.b->getArgument(i).dropAllUses();
       b.b->eraseArgument(i);
     }
   }
 
   // 7. Successor Operands
+  LLVM_DEBUG({
+    llvm::dbgs() << "Cleaning up " << list.successorOperands.size()
+                 << " successor operand lists"
+                 << "\n";
+  });
   for (auto &op : list.successorOperands) {
     SuccessorOperands successorOperands =
         op.branch.getSuccessorOperands(op.successorIndex);
     // blocks that are accessed via multiple codepaths processed once
     if (successorOperands.size() != op.nonLiveOperands.size())
       continue;
+    LLVM_DEBUG({
+      llvm::dbgs() << "Erasing " << op.nonLiveOperands.count()
+                   << " non-live successor operands from successor "
+                   << op.successorIndex << " of branch: "
+                   << OpWithFlags(op.branch, OpPrintingFlags().skipRegions())
+                   << "\n";
+    });
     // it iterates backwards because erase invalidates all successor indexes
     for (int i = successorOperands.size() - 1; i >= 0; --i) {
       if (!op.nonLiveOperands[i])
         continue;
+      LLVM_DEBUG({
+        llvm::dbgs() << "  Erasing successor operand " << i << ": "
+                     << successorOperands[i] << "\n";
+      });
       successorOperands.erase(i);
     }
   }
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "Finished cleanup of dead values"
+                 << "\n";
+  });
 }
 
 struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
diff --git a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
index a89a0f4084e99..3748be74eb0f3 100644
--- a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
+++ b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir
@@ -283,3 +283,23 @@ func.func @test_10_negative() -> (i32) {
   %0:2 = func.call @private_1() : () -> (i32, i32)
   return %0#0 : i32
 }
+
+// -----
+
+// Test that we correctly set a liveness value for operations in dead blocks.
+// These won't be visited by the dataflow framework, so the analysis needs to
+// explicitly manage them.
+// CHECK-LABEL: test_tag: dead_block_cmpi:
+// CHECK-NEXT: operand #0: not live
+// CHECK-NEXT: operand #1: not live
+// CHECK-NEXT: result #0: not live
+func.func @dead_block() {
+  %false = arith.constant false
+  %zero = arith.constant 0 : i64
+  cf.cond_br %false, ^bb1, ^bb4
+  ^bb1:
+    %3 = arith.cmpi eq, %zero, %zero  {tag = "dead_block_cmpi"} : i64
+    cf.br ^bb1
+  ^bb4:
+    return
+}
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 9ded6a30d9c95..0f8d757086e87 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -571,3 +571,24 @@ module @return_void_with_unused_argument {
   }
 }
 
+// -----
+
+// CHECK-LABEL: module @dynamically_unreachable
+module @dynamically_unreachable {
+  func.func @dynamically_unreachable() {
+    // This value is used by an operation in a dynamically unreachable block.
+    %zero = arith.constant 0 : i64
+
+    // Dataflow analysis knows from the constant condition that
+    // ^bb1 is unreachable
+    %false = arith.constant false
+    cf.cond_br %false, ^bb1, ^bb4
+  ^bb1:
+    // This unreachable operation should be removed.
+    // CHECK-NOT: arith.cmpi
+    %3 = arith.cmpi eq, %zero, %zero : i64
+    cf.br ^bb1
+  ^bb4:
+    return
+  }
+}
diff --git a/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp
index 43005e22584c2..8e2f03b644e49 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestLivenessAnalysis.cpp
@@ -33,7 +33,6 @@ struct TestLivenessAnalysisPass
 
   void runOnOperation() override {
     auto &livenessAnalysis = getAnalysis<RunLivenessAnalysis>();
-
     Operation *op = getOperation();
 
     raw_ostream &os = llvm::outs();

>From 9503308005b9768e7cac702805214f53bf4d09b3 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Fri, 1 Aug 2025 13:11:19 +0200
Subject: [PATCH 21/22] [MLIR] Introduce a OpWithState class to act as a stream
 modifier for Operations (NFC) (#151547)

Modeled on OpWithFlags, this modifier allows streaming an operation
using a custom AsmState.
---
 mlir/include/mlir/IR/Operation.h        | 43 +++++++++++++++++++++++++
 mlir/lib/Tools/mlir-opt/MlirOptMain.cpp |  3 +-
 2 files changed, 44 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 1c2c04e718bf7..11af3b7d4d7b6 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -1102,6 +1102,49 @@ inline raw_ostream &operator<<(raw_ostream &os, const Operation &op) {
   return os;
 }
 
+/// A wrapper class that allows for printing an operation with a set of flags,
+/// useful to act as a "stream modifier" to customize printing an operation
+/// with a stream using the operator<< overload, e.g.:
+///   llvm::dbgs() << OpWithFlags(op, OpPrintingFlags().skipRegions());
+class OpWithFlags {
+public:
+  OpWithFlags(Operation *op, OpPrintingFlags flags = {})
+      : op(op), theFlags(flags) {}
+  OpPrintingFlags &flags() { return theFlags; }
+  const OpPrintingFlags &flags() const { return theFlags; }
+
+private:
+  Operation *op;
+  OpPrintingFlags theFlags;
+  friend raw_ostream &operator<<(raw_ostream &os, const OpWithFlags &op);
+};
+
+inline raw_ostream &operator<<(raw_ostream &os,
+                               const OpWithFlags &opWithFlags) {
+  opWithFlags.op->print(os, opWithFlags.flags());
+  return os;
+}
+
+/// A wrapper class that prints an operation using a caller-provided AsmState,
+/// acting as a "stream modifier" through the operator<< overload. The
+/// AsmState is held by reference and must outlive the wrapper, e.g.:
+///   llvm::dbgs() << OpWithState(op, asmState);
+class OpWithState {
+public:
+  OpWithState(Operation *op, AsmState &state) : op(op), theState(state) {}
+
+private:
+  Operation *op; // The operation to print (not owned).
+  AsmState &theState; // Not owned; must outlive this wrapper.
+  friend raw_ostream &operator<<(raw_ostream &os, const OpWithState &op);
+};
+
+inline raw_ostream &operator<<(raw_ostream &os,
+                               const OpWithState &opWithState) {
+  opWithState.op->print(os, const_cast<OpWithState &>(opWithState).theState);
+  return os;
+}
+
 } // namespace mlir
 
 namespace llvm {
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 31e0caa768113..9f2a5c761b5ce 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -508,8 +508,7 @@ performActions(raw_ostream &os,
            << "bytecode version while not emitting bytecode";
   AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr,
                     &fallbackResourceMap);
-  op.get()->print(os, asmState);
-  os << '\n';
+  os << OpWithState(op.get(), asmState) << '\n';
   return success();
 }
 

>From f9a92b316af4cdf6a044b2f5cc6a17bcfa2994f3 Mon Sep 17 00:00:00 2001
From: Hank <49036880+hankluo6 at users.noreply.github.com>
Date: Wed, 20 Aug 2025 06:03:26 -0700
Subject: [PATCH 22/22] [MLIR] Fix duplicated attribute nodes in MLIR bytecode
 deserialization (#151267)

Fixes #150163

MLIR bytecode does not preserve alias definitions, so each attribute
encountered during deserialization is treated as a new one. This can
generate duplicate `DISubprogram` nodes during deserialization.

The patch adds a `StringMap` cache that records attributes and fetches
them when encountered again.
---
 mlir/include/mlir/AsmParser/AsmParser.h     |  3 ++-
 mlir/lib/AsmParser/DialectSymbolParser.cpp  | 24 +++++++++++++++++++--
 mlir/lib/AsmParser/ParserState.h            |  3 +++
 mlir/lib/Bytecode/Reader/BytecodeReader.cpp |  6 +++++-
 mlir/test/IR/recursive-distinct-attr.mlir   | 13 +++++++++++
 5 files changed, 45 insertions(+), 4 deletions(-)
 create mode 100644 mlir/test/IR/recursive-distinct-attr.mlir

diff --git a/mlir/include/mlir/AsmParser/AsmParser.h b/mlir/include/mlir/AsmParser/AsmParser.h
index 33daf7ca26f49..f39b3bd853a2a 100644
--- a/mlir/include/mlir/AsmParser/AsmParser.h
+++ b/mlir/include/mlir/AsmParser/AsmParser.h
@@ -53,7 +53,8 @@ parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block,
 /// null terminated.
 Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context,
                          Type type = {}, size_t *numRead = nullptr,
-                         bool isKnownNullTerminated = false);
+                         bool isKnownNullTerminated = false,
+                         llvm::StringMap<Attribute> *attributesCache = nullptr);
 
 /// This parses a single MLIR type to an MLIR context if it was valid. If not,
 /// an error diagnostic is emitted to the context.
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 9f4a87a6a02de..bae845e9018c4 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -238,6 +238,15 @@ static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
       return nullptr;
   }
 
+  if constexpr (std::is_same_v<Symbol, Attribute>) {
+    auto &cache = p.getState().symbols.attributesCache;
+    auto cacheIt = cache.find(symbolData);
+    // Skip cached attribute if it has type.
+    if (cacheIt != cache.end() && !p.getToken().is(Token::colon))
+      return cacheIt->second;
+
+    return cache[symbolData] = createSymbol(dialectName, symbolData, loc);
+  }
   return createSymbol(dialectName, symbolData, loc);
 }
 
@@ -330,6 +339,7 @@ Type Parser::parseExtendedType() {
 template <typename T, typename ParserFn>
 static T parseSymbol(StringRef inputStr, MLIRContext *context,
                      size_t *numReadOut, bool isKnownNullTerminated,
+                     llvm::StringMap<Attribute> *attributesCache,
                      ParserFn &&parserFn) {
   // Set the buffer name to the string being parsed, so that it appears in error
   // diagnostics.
@@ -341,6 +351,9 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
   SourceMgr sourceMgr;
   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
   SymbolState aliasState;
+  if (attributesCache)
+    aliasState.attributesCache = *attributesCache;
+
   ParserConfig config(context);
   ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
                     /*codeCompleteContext=*/nullptr);
@@ -351,6 +364,11 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
   if (!symbol)
     return T();
 
+  if constexpr (std::is_same_v<T, Attribute>) {
+    if (attributesCache)
+      *attributesCache = state.symbols.attributesCache;
+  }
+
   // Provide the number of bytes that were read.
   Token endTok = parser.getToken();
   size_t numRead =
@@ -367,13 +385,15 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
 
 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
                                Type type, size_t *numRead,
-                               bool isKnownNullTerminated) {
+                               bool isKnownNullTerminated,
+                               llvm::StringMap<Attribute> *attributesCache) {
   return parseSymbol<Attribute>(
-      attrStr, context, numRead, isKnownNullTerminated,
+      attrStr, context, numRead, isKnownNullTerminated, attributesCache,
       [type](Parser &parser) { return parser.parseAttribute(type); });
 }
 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead,
                      bool isKnownNullTerminated) {
   return parseSymbol<Type>(typeStr, context, numRead, isKnownNullTerminated,
+                           /*attributesCache=*/nullptr,
                            [](Parser &parser) { return parser.parseType(); });
 }
diff --git a/mlir/lib/AsmParser/ParserState.h b/mlir/lib/AsmParser/ParserState.h
index 159058a18fa4e..aa53032107cbf 100644
--- a/mlir/lib/AsmParser/ParserState.h
+++ b/mlir/lib/AsmParser/ParserState.h
@@ -40,6 +40,9 @@ struct SymbolState {
 
   /// A map from unique integer identifier to DistinctAttr.
   DenseMap<uint64_t, DistinctAttr> distinctAttributes;
+
+  /// A map from unique string identifier to Attribute.
+  llvm::StringMap<Attribute> attributesCache;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 44458d010c6c8..0f97443433774 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -895,6 +895,10 @@ class AttrTypeReader {
   SmallVector<AttrEntry> attributes;
   SmallVector<TypeEntry> types;
 
+  /// The map of cached attributes, used to avoid re-parsing the same
+  /// attribute multiple times.
+  llvm::StringMap<Attribute> attributesCache;
+
   /// A location used for error emission.
   Location fileLoc;
 
@@ -1235,7 +1239,7 @@ LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
         ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
   else
     result = ::parseAttribute(asmStr, context, Type(), &numRead,
-                              /*isKnownNullTerminated=*/true);
+                              /*isKnownNullTerminated=*/true, &attributesCache);
   if (!result)
     return failure();
 
diff --git a/mlir/test/IR/recursive-distinct-attr.mlir b/mlir/test/IR/recursive-distinct-attr.mlir
new file mode 100644
index 0000000000000..5afb5c59e0fcf
--- /dev/null
+++ b/mlir/test/IR/recursive-distinct-attr.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt -emit-bytecode %s | mlir-opt --mlir-print-debuginfo | FileCheck %s
+
+// Verify that the distinct attribute which is used transitively
+// through two aliases does not end up duplicated when round-tripped
+// through bytecode.
+
+// CHECK: distinct[0]
+// CHECK-NOT: distinct[1]
+#attr_ugly = #test<attr_ugly begin distinct[0]<> end>
+#attr_ugly1 = #test<attr_ugly begin #attr_ugly end>
+
+module attributes {test.alias = #attr_ugly, test.alias1 = #attr_ugly1} {
+}
\ No newline at end of file



More information about the llvm-branch-commits mailing list