[Mlir-commits] [mlir] [MLIR][GPU] Add a pattern to rewrite gpu.subgroup_id (PR #137671)

Alan Li llvmlistbot at llvm.org
Mon Apr 28 20:04:50 PDT 2025


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/137671

>From a19415fefaacb8d171711097f0bafa73510f600c Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 28 Apr 2025 12:27:06 -0400
Subject: [PATCH 1/3] [MLIR][GPU] Add a pattern to rewrite gpu.subgroup_id

This patch implements a rewrite pattern for transforming `gpu.subgroup_id`
to:
```
subgroup_id = linearized_thread_id / gpu.subgroup_size
```

where:
```
linearized_thread_id = thread_id.x + block_dim.x * (thread_id.y + block_dim.y * thread_id.z)
```
---
 .../mlir/Dialect/GPU/Transforms/Passes.h      |  5 ++
 mlir/lib/Dialect/GPU/CMakeLists.txt           |  1 +
 .../GPU/Transforms/SubgroupIdRewriter.cpp     | 82 +++++++++++++++++++
 mlir/test/Dialect/GPU/subgroupId-rewrite.mlir | 26 ++++++
 4 files changed, 114 insertions(+)
 create mode 100644 mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
 create mode 100644 mlir/test/Dialect/GPU/subgroupId-rewrite.mlir

diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index a13ad33df29cd..cbb990e603a38 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -39,6 +39,10 @@ class FuncOp;
 /// Collect a set of patterns to rewrite GlobalIdOp op within the GPU dialect.
 void populateGpuGlobalIdPatterns(RewritePatternSet &patterns);
 
+/// Collect a set of patterns to rewrite SubgroupIdOp op within the GPU
+/// dialect.
+void populateGpuSubgroupIdPatterns(RewritePatternSet &patterns);
+
 /// Collect a set of patterns to rewrite shuffle ops within the GPU dialect.
 void populateGpuShufflePatterns(RewritePatternSet &patterns);
 
@@ -88,6 +92,7 @@ inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
   populateGpuAllReducePatterns(patterns);
   populateGpuGlobalIdPatterns(patterns);
   populateGpuShufflePatterns(patterns);
+  populateGpuSubgroupIdPatterns(patterns);
 }
 
 namespace gpu {
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index be6492a22f34f..e21fa501bae6b 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -40,6 +40,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/ROCDLAttachTarget.cpp
   Transforms/ShuffleRewriter.cpp
   Transforms/SPIRVAttachTarget.cpp
+  Transforms/SubgroupIdRewriter.cpp
   Transforms/SubgroupReduceLowering.cpp
 
   OBJECT
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
new file mode 100644
index 0000000000000..1c322c1016c01
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
@@ -0,0 +1,82 @@
+//===- SubgroupIdRewriter.cpp - Implementation of SugroupId rewriting  ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements in-dialect rewriting of the gpu.subgroup_id op for archs
+// where:
+// subgroup_id = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / subgroup_size
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+struct GpuSubgroupIdRewriter final : OpRewritePattern<gpu::SubgroupIdOp> {
+  using OpRewritePattern<gpu::SubgroupIdOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(gpu::SubgroupIdOp op,
+                                PatternRewriter &rewriter) const override {
+    // Calculation of the thread's subgroup identifier.
+    //
+    // The process involves mapping the thread's 3D identifier within its
+    // block (tid.x, tid.y, tid.z) to a 1D linear index.
+    // This linearization assumes a layout where the x-dimension (tid.x)
+    // varies most rapidly (i.e., it is the innermost dimension).
+    //
+    // The formula for the linearized thread index is:
+    // L = tid.x + dim.x * (tid.y + (dim.y * tid.z))
+    //
+    // Subsequently, the range of linearized indices [0, N_threads-1] is
+    // divided into consecutive, non-overlapping segments, each representing
+    // a subgroup of size 'subgroup_size'.
+    //
+    // Example Partitioning (N = subgroup_size):
+    // | Subgroup 0      | Subgroup 1      | Subgroup 2      | ... |
+    // | Indices 0..N-1  | Indices N..2N-1 | Indices 2N..3N-1| ... |
+    //
+    // The subgroup identifier is obtained via integer division of the
+    // linearized thread index by the predefined 'subgroup_size'.
+    //
+    // subgroup_id = floor( L / subgroup_size )
+    //             = (tid.x + dim.x * (tid.y + dim.y * tid.z)) /
+    //             subgroup_size
+
+    auto loc = op->getLoc();
+
+    Value dimX = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
+    Value dimY = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::y);
+    Value tidX = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
+    Value tidY = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::y);
+    Value tidZ = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::z);
+
+    Value dimYxIdZ = rewriter.create<index::MulOp>(loc, dimY, tidZ);
+    Value dimYxIdZPlusIdY = rewriter.create<index::AddOp>(loc, dimYxIdZ, tidY);
+    Value dimYxIdZPlusIdYTimesDimX =
+        rewriter.create<index::MulOp>(loc, dimX, dimYxIdZPlusIdY);
+    Value IdXPlusDimYxIdZPlusIdYTimesDimX =
+        rewriter.create<index::AddOp>(loc, tidX, dimYxIdZPlusIdYTimesDimX);
+    Value subgroupSize = rewriter.create<gpu::SubgroupSizeOp>(
+        loc, rewriter.getIndexType(), /*upper_bound = */ nullptr);
+    Value subgroupIdOp = rewriter.create<index::DivUOp>(
+        loc, IdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
+    rewriter.replaceOp(op, {subgroupIdOp});
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::populateGpuSubgroupIdPatterns(RewritePatternSet &patterns) {
+  patterns.add<GpuSubgroupIdRewriter>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/GPU/subgroupId-rewrite.mlir b/mlir/test/Dialect/GPU/subgroupId-rewrite.mlir
new file mode 100644
index 0000000000000..02fcb2ba21dad
--- /dev/null
+++ b/mlir/test/Dialect/GPU/subgroupId-rewrite.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt --test-gpu-rewrite -split-input-file %s | FileCheck %s
+
+module {
+  // CHECK-LABEL: func.func @subgroupId
+  // CHECK-SAME: (%[[SZ:.*]]: index, %[[MEM:.*]]: memref<index, 1>) {
+  func.func @subgroupId(%sz : index, %mem: memref<index, 1>) {
+    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
+               threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
+      // CHECK: %[[DIMX:.*]] = gpu.block_dim  x
+      // CHECK-NEXT: %[[DIMY:.*]] = gpu.block_dim  y
+      // CHECK-NEXT: %[[TIDX:.*]] = gpu.thread_id  x
+      // CHECK-NEXT: %[[TIDY:.*]] = gpu.thread_id  y
+      // CHECK-NEXT: %[[TIDZ:.*]] = gpu.thread_id  z
+      // CHECK-NEXT: %[[T0:.*]] = index.mul %[[DIMY]], %[[TIDZ]]
+      // CHECK-NEXT: %[[T1:.*]] = index.add %[[T0]], %[[TIDY]]
+      // CHECK-NEXT: %[[T2:.*]] = index.mul %[[DIMX]], %[[T1]]
+      // CHECK-NEXT: %[[T3:.*]] = index.add %[[TIDX]], %[[T2]]
+      // CHECK-NEXT: %[[T4:.*]] = gpu.subgroup_size : index
+      // CHECK-NEXT: %[[T5:.*]] = index.divu %[[T3]], %[[T4]]
+      %idz = gpu.subgroup_id : index
+      memref.store %idz, %mem[] : memref<index, 1>
+      gpu.terminator
+    }
+    return
+  }
+}

>From fbe3bd189e6d26cb5f2bc68143238d1024d4aa7e Mon Sep 17 00:00:00 2001
From: Alan Li <alan.li at me.com>
Date: Mon, 28 Apr 2025 16:21:20 -0400
Subject: [PATCH 2/3] Update
 mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp

Co-authored-by: Copilot <175728472+Copilot at users.noreply.github.com>
---
 mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
index 1c322c1016c01..5fa095564c545 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
@@ -1,4 +1,4 @@
-//===- SubgroupIdRewriter.cpp - Implementation of SugroupId rewriting  ----===//
+//===- SubgroupIdRewriter.cpp - Implementation of SubgroupId rewriting  ----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

>From e316a6e39a37796af44798a03bb3cc8b051b999b Mon Sep 17 00:00:00 2001
From: Alan Li <me at alanli.org>
Date: Mon, 28 Apr 2025 16:44:00 -0400
Subject: [PATCH 3/3] remove

---
 .../mlir/Dialect/GPU/Transforms/Passes.h      |  1 -
 .../GPU/Transforms/GlobalIdRewriter.cpp       |  2 +-
 .../GPU/Transforms/SubgroupIdRewriter.cpp     |  2 +-
 mlir/test/Dialect/GPU/subgroupId-rewrite.mlir | 42 +++++++++----------
 mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp  |  1 +
 5 files changed, 23 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index cbb990e603a38..6cd6f03253aea 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -92,7 +92,6 @@ inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
   populateGpuAllReducePatterns(patterns);
   populateGpuGlobalIdPatterns(patterns);
   populateGpuShufflePatterns(patterns);
-  populateGpuSubgroupIdPatterns(patterns);
 }
 
 namespace gpu {
diff --git a/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp
index 0c730df73b519..c40ddd9b15afc 100644
--- a/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp
@@ -26,7 +26,7 @@ struct GpuGlobalIdRewriter : public OpRewritePattern<gpu::GlobalIdOp> {
 
   LogicalResult matchAndRewrite(gpu::GlobalIdOp op,
                                 PatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
+    Location loc = op.getLoc();
     auto dim = op.getDimension();
     auto blockId = rewriter.create<gpu::BlockIdOp>(loc, dim);
     auto blockDim = rewriter.create<gpu::BlockDimOp>(loc, dim);
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
index 5fa095564c545..72099371c0700 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
@@ -1,4 +1,4 @@
-//===- SubgroupIdRewriter.cpp - Implementation of SubgroupId rewriting  ----===//
+//===- SubgroupIdRewriter.cpp - Implementation of SubgroupId rewriting ----===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/test/Dialect/GPU/subgroupId-rewrite.mlir b/mlir/test/Dialect/GPU/subgroupId-rewrite.mlir
index 02fcb2ba21dad..a0c852f6fbe88 100644
--- a/mlir/test/Dialect/GPU/subgroupId-rewrite.mlir
+++ b/mlir/test/Dialect/GPU/subgroupId-rewrite.mlir
@@ -1,26 +1,24 @@
 // RUN: mlir-opt --test-gpu-rewrite -split-input-file %s | FileCheck %s
 
-module {
-  // CHECK-LABEL: func.func @subgroupId
-  // CHECK-SAME: (%[[SZ:.*]]: index, %[[MEM:.*]]: memref<index, 1>) {
-  func.func @subgroupId(%sz : index, %mem: memref<index, 1>) {
-    gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
-               threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
-      // CHECK: %[[DIMX:.*]] = gpu.block_dim  x
-      // CHECK-NEXT: %[[DIMY:.*]] = gpu.block_dim  y
-      // CHECK-NEXT: %[[TIDX:.*]] = gpu.thread_id  x
-      // CHECK-NEXT: %[[TIDY:.*]] = gpu.thread_id  y
-      // CHECK-NEXT: %[[TIDZ:.*]] = gpu.thread_id  z
-      // CHECK-NEXT: %[[T0:.*]] = index.mul %[[DIMY]], %[[TIDZ]]
-      // CHECK-NEXT: %[[T1:.*]] = index.add %[[T0]], %[[TIDY]]
-      // CHECK-NEXT: %[[T2:.*]] = index.mul %[[DIMX]], %[[T1]]
-      // CHECK-NEXT: %[[T3:.*]] = index.add %[[TIDX]], %[[T2]]
-      // CHECK-NEXT: %[[T4:.*]] = gpu.subgroup_size : index
-      // CHECK-NEXT: %[[T5:.*]] = index.divu %[[T3]], %[[T4]]
-      %idz = gpu.subgroup_id : index
-      memref.store %idz, %mem[] : memref<index, 1>
-      gpu.terminator
-    }
-    return
+// CHECK-LABEL: func.func @subgroupId
+// CHECK-SAME: (%[[SZ:.*]]: index, %[[MEM:.*]]: memref<index, 1>) {
+func.func @subgroupId(%sz : index, %mem: memref<index, 1>) {
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
+             threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
+    // CHECK: %[[DIMX:.*]] = gpu.block_dim  x
+    // CHECK-NEXT: %[[DIMY:.*]] = gpu.block_dim  y
+    // CHECK-NEXT: %[[TIDX:.*]] = gpu.thread_id  x
+    // CHECK-NEXT: %[[TIDY:.*]] = gpu.thread_id  y
+    // CHECK-NEXT: %[[TIDZ:.*]] = gpu.thread_id  z
+    // CHECK-NEXT: %[[T0:.*]] = index.mul %[[DIMY]], %[[TIDZ]]
+    // CHECK-NEXT: %[[T1:.*]] = index.add %[[T0]], %[[TIDY]]
+    // CHECK-NEXT: %[[T2:.*]] = index.mul %[[DIMX]], %[[T1]]
+    // CHECK-NEXT: %[[T3:.*]] = index.add %[[TIDX]], %[[T2]]
+    // CHECK-NEXT: %[[T4:.*]] = gpu.subgroup_size : index
+    // CHECK-NEXT: %[[T5:.*]] = index.divu %[[T3]], %[[T4]]
+    %idz = gpu.subgroup_id : index
+    memref.store %idz, %mem[] : memref<index, 1>
+    gpu.terminator
   }
+  return
 }
diff --git a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
index fe402da4cc105..616f458e4824c 100644
--- a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
+++ b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
@@ -41,6 +41,7 @@ struct TestGpuRewritePass
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     populateGpuRewritePatterns(patterns);
+    populateGpuSubgroupIdPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 };



More information about the Mlir-commits mailing list