[Mlir-commits] [mlir] [mlir][amdgpu] Fuse adjacent `MemoryCounterWaitOp` (PR #171148)

Ivan Butygin llvmlistbot at llvm.org
Mon Dec 8 07:38:51 PST 2025


https://github.com/Hardcode84 updated https://github.com/llvm/llvm-project/pull/171148

>From a27147ee18be7288d119d626583a177f34a27131 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 8 Dec 2025 16:06:03 +0100
Subject: [PATCH 1/3] [mlir][amdgpu] Fuse adjacent `MemoryCounterWaitOp`

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  2 +
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  | 43 +++++++++++++++++++
 mlir/test/Dialect/AMDGPU/canonicalize.mlir    | 11 +++++
 3 files changed, 56 insertions(+)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index ec9f449c35dc4..5eda52b7cc68d 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -921,6 +921,8 @@ def AMDGPU_MemoryCounterWaitOp :
   let assemblyFormat = [{
     oilist( `load` `(` $load `)` | `store` `(` $store `)` | `ds` `(` $ds `)` | `exp` `(` $exp `)` ) attr-dict
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def AMDGPU_MFMAPermB : I32EnumAttr<"MFMAPermB",
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index cf74f671db216..0b2eb6827ceeb 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -596,6 +596,49 @@ LogicalResult PermlaneSwapOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// MemoryCounterWaitOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Fuse adjacent memory counter wait ops, taking the minimum value of the
+/// counters.
+struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
+                                PatternRewriter &rewriter) const override {
+    // getNextNode() is null when `op` is the last operation in its block.
+    auto next = dyn_cast_if_present<MemoryCounterWaitOp>(op->getNextNode());
+    if (!next)
+      return failure();
+    auto setters = {&MemoryCounterWaitOp::setLoad,
+                    &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
+                    &MemoryCounterWaitOp::setExp};
+    auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp()};
+    auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
+                    next.getExp()};
+    for (const auto &[setter, lhs, rhs] :
+         llvm::zip(setters, lhsVals, rhsVals)) {
+      if (lhs && rhs) {
+        (op.*setter)(std::min(*lhs, *rhs));
+      } else if (lhs) {
+        (op.*setter)(*lhs);
+      } else if (rhs) {
+        (op.*setter)(*rhs);
+      }
+    }
+    rewriter.eraseOp(next);
+    return success();
+  }
+};
+} // namespace
+
+void MemoryCounterWaitOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<FuseMemoryCounterWaitOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // GatherToLDSOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index fee0c00606ab4..0a8243a3f02ad 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -244,3 +244,14 @@ func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
   %res_7 = amdgpu.scaled_mfma 16x16x128 (%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   return %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL fuse_memory_counter_wait
+func.func @fuse_memory_counter_wait() {
+  //      CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(2) exp(1)
+  // CHECK-NEXT: return
+  amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
+  amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1)
+  return
+}

>From 7698e0c043d6c29b98f1e9d764d566405a100af1 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 8 Dec 2025 16:23:46 +0100
Subject: [PATCH 2/3] review comments

---
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 22 +++++++++++---------
 mlir/test/Dialect/AMDGPU/canonicalize.mlir   | 14 +++++++++++++
 2 files changed, 26 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 0b2eb6827ceeb..d20e77ae094bf 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -604,7 +604,7 @@ namespace {
 /// Fuse adjacent memory counter wait ops, taking the minimum value of the
 /// counters.
 struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
-  using OpRewritePattern::OpRewritePattern;
+  using Base::Base;
 
   LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
                                 PatternRewriter &rewriter) const override {
@@ -618,16 +618,18 @@ struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
     auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp()};
     auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
                     next.getExp()};
-    for (const auto &[setter, lhs, rhs] :
-         llvm::zip(setters, lhsVals, rhsVals)) {
-      if (lhs && rhs) {
-        (op.*setter)(std::min(*lhs, *rhs));
-      } else if (lhs) {
-        (op.*setter)(*lhs);
-      } else if (rhs) {
-        (op.*setter)(*rhs);
+    rewriter.modifyOpInPlace(op, [&] {
+      for (auto [setter, lhs, rhs] :
+           llvm::zip_equal(setters, lhsVals, rhsVals)) {
+        if (lhs && rhs) {
+          (op.*setter)(std::min(*lhs, *rhs));
+        } else if (lhs) {
+          (op.*setter)(*lhs);
+        } else if (rhs) {
+          (op.*setter)(*rhs);
+        }
       }
-    }
+    });
     rewriter.eraseOp(next);
     return success();
   }
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 0a8243a3f02ad..b57b381ebb4f6 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -255,3 +255,17 @@ func.func @fuse_memory_counter_wait() {
   amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1)
   return
 }
+
+func.func private @use()
+
+// CHECK-LABEL fuse_memory_counter_wait_not_adjacent
+func.func @fuse_memory_counter_wait_not_adjacent() {
+  //      CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
+  // CHECK-NEXT: call @use()
+  // CHECK-NEXT: amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1)
+  // CHECK-NEXT: return
+  amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
+  func.call @use() : () -> ()
+  amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1)
+  return
+}

>From 1a94ec5ccea326f4b9abacc30cc56c1e1ae36f47 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 8 Dec 2025 16:32:09 +0100
Subject: [PATCH 3/3] tests

---
 mlir/test/Dialect/AMDGPU/canonicalize.mlir | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index b57b381ebb4f6..c66e9ed5d6f6d 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -249,13 +249,24 @@ func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
 
 // CHECK-LABEL fuse_memory_counter_wait
 func.func @fuse_memory_counter_wait() {
-  //      CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(2) exp(1)
+  //      CHECK: amdgpu.memory_counter_wait
+  // CHECK-SAME: load(1) store(2) ds(2) exp(1)
   // CHECK-NEXT: return
   amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
   amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1)
   return
 }
 
+// CHECK-LABEL: fuse_memory_counter_wait_different_counters
+func.func @fuse_memory_counter_wait_different_counters() {
+  //      CHECK: amdgpu.memory_counter_wait
+  // CHECK-SAME: load(1) store(2) ds(3) exp(4)
+  // CHECK-NEXT: return
+  amdgpu.memory_counter_wait load(1) store(2)
+  amdgpu.memory_counter_wait ds(3) exp(4)
+  return
+}
+
 func.func private @use()
 
 // CHECK-LABEL fuse_memory_counter_wait_not_adjacent



More information about the Mlir-commits mailing list