[Mlir-commits] [mlir] [mlir][amdgpu] Fuse adjacent `MemoryCounterWaitOp` (PR #171148)
Ivan Butygin
llvmlistbot at llvm.org
Mon Dec 8 07:14:38 PST 2025
https://github.com/Hardcode84 created https://github.com/llvm/llvm-project/pull/171148
Adjacent `amdgpu.memory_counter_wait` ops are fused into a single op, taking the minimum value of each counter.
>From a27147ee18be7288d119d626583a177f34a27131 Mon Sep 17 00:00:00 2001
From: Ivan Butygin <ivan.butygin at gmail.com>
Date: Mon, 8 Dec 2025 16:06:03 +0100
Subject: [PATCH] [mlir][amdgpu] Fuse adjacent `MemoryCounterWaitOp`
---
mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 2 +
mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 43 +++++++++++++++++++
mlir/test/Dialect/AMDGPU/canonicalize.mlir | 11 +++++
3 files changed, 56 insertions(+)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index ec9f449c35dc4..5eda52b7cc68d 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -921,6 +921,8 @@ def AMDGPU_MemoryCounterWaitOp :
let assemblyFormat = [{
oilist( `load` `(` $load `)` | `store` `(` $store `)` | `ds` `(` $ds `)` | `exp` `(` $exp `)` ) attr-dict
}];
+
+ let hasCanonicalizer = 1;
}
def AMDGPU_MFMAPermB : I32EnumAttr<"MFMAPermB",
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index cf74f671db216..0b2eb6827ceeb 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -596,6 +596,49 @@ LogicalResult PermlaneSwapOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// MemoryCounterWaitOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Fuse adjacent memory counter wait ops, taking the minimum value of the
+/// counters. A counter set on only one of the two ops is kept as is.
+struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
+                                PatternRewriter &rewriter) const override {
+    // `getNextNode()` is null when `op` ends its block; tolerate that.
+    auto next = dyn_cast_if_present<MemoryCounterWaitOp>(op->getNextNode());
+    if (!next)
+      return failure();
+
+    auto setters = {&MemoryCounterWaitOp::setLoad,
+                    &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
+                    &MemoryCounterWaitOp::setExp};
+    auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp()};
+    auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
+                    next.getExp()};
+    // Mutate through the rewriter so the driver is notified of the change.
+    rewriter.modifyOpInPlace(op, [&] {
+      for (const auto &[setter, lhs, rhs] :
+           llvm::zip(setters, lhsVals, rhsVals)) {
+        // A counter absent from `next` leaves the value on `op` untouched.
+        if (rhs)
+          (op.*setter)(lhs ? std::min(*lhs, *rhs) : *rhs);
+      }
+    });
+    rewriter.eraseOp(next);
+    return success();
+  }
+};
+} // namespace
+
+void MemoryCounterWaitOp::getCanonicalizationPatterns(
+ RewritePatternSet &results, MLIRContext *context) {
+ results.add<FuseMemoryCounterWaitOp>(context); // Fuses adjacent wait ops.
+}
+
//===----------------------------------------------------------------------===//
// GatherToLDSOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index fee0c00606ab4..0a8243a3f02ad 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -244,3 +244,14 @@ func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
%res_7 = amdgpu.scaled_mfma 16x16x128 (%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
return %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @fuse_memory_counter_wait
+func.func @fuse_memory_counter_wait() {
+  // CHECK: amdgpu.memory_counter_wait load(1) store(2) ds(2) exp(1)
+  // CHECK-NEXT: return
+  amdgpu.memory_counter_wait load(1) store(2) ds(3) exp(4)
+  amdgpu.memory_counter_wait load(4) store(3) ds(2) exp(1)
+  return
+}
More information about the Mlir-commits
mailing list