[Mlir-commits] [mlir] ad8d9e1 - [mlir][gpu] Use `arith` dialect to lower gpu.global_id (#171614)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Dec 13 02:43:17 PST 2025


Author: Longsheng Mou
Date: 2025-12-13T18:43:12+08:00
New Revision: ad8d9e1428721f161c78e3334a8ecee0ebeb2487

URL: https://github.com/llvm/llvm-project/commit/ad8d9e1428721f161c78e3334a8ecee0ebeb2487
DIFF: https://github.com/llvm/llvm-project/commit/ad8d9e1428721f161c78e3334a8ecee0ebeb2487.diff

LOG: [mlir][gpu] Use `arith` dialect to lower gpu.global_id (#171614)

This PR lowers the `gpu.global_id` op using the arith dialect instead of
the index dialect. Fixes #171303.

Added: 
    mlir/test/Conversion/GPUCommon/lower-global-id.mlir

Modified: 
    mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp
    mlir/test/Dialect/GPU/globalId-rewrite.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp
index 6519b65cec465..e55a695be13c8 100644
--- a/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/GlobalIdRewriter.cpp
@@ -11,9 +11,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
-#include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/IR/PatternMatch.h"
 
 using namespace mlir;
@@ -26,13 +26,15 @@ struct GpuGlobalIdRewriter : public OpRewritePattern<gpu::GlobalIdOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     auto dim = op.getDimension();
-    auto blockId = gpu::BlockIdOp::create(rewriter, loc, dim);
-    auto blockDim = gpu::BlockDimOp::create(rewriter, loc, dim);
+    Value blockId = gpu::BlockIdOp::create(rewriter, loc, dim);
+    Value blockDim = gpu::BlockDimOp::create(rewriter, loc, dim);
+    auto indexType = rewriter.getIndexType();
     // Compute blockId.x * blockDim.x
-    auto tmp = index::MulOp::create(rewriter, op.getLoc(), blockId, blockDim);
-    auto threadId = gpu::ThreadIdOp::create(rewriter, loc, dim);
+    Value tmp =
+        arith::MulIOp::create(rewriter, loc, indexType, blockId, blockDim);
+    Value threadId = gpu::ThreadIdOp::create(rewriter, loc, dim);
     // Compute threadId.x + blockId.x * blockDim.x
-    rewriter.replaceOpWithNewOp<index::AddOp>(op, threadId, tmp);
+    rewriter.replaceOpWithNewOp<arith::AddIOp>(op, indexType, threadId, tmp);
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/GPUCommon/lower-global-id.mlir b/mlir/test/Conversion/GPUCommon/lower-global-id.mlir
new file mode 100644
index 0000000000000..b0274e0f9f290
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/lower-global-id.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -split-input-file -convert-gpu-to-rocdl | FileCheck %s --check-prefixes=ROCDL
+// RUN: mlir-opt %s -split-input-file -convert-gpu-to-nvvm | FileCheck %s --check-prefixes=NVVM
+
+gpu.module @kernel {
+  gpu.func @gpu_global_id() -> (index) {
+    %global_id_x = gpu.global_id x
+    gpu.return %global_id_x : index
+  }
+}
+
+// ROCDL-LABEL:   llvm.func @gpu_global_id() -> i64 {
+// ROCDL:           %[[WORKGROUP_0:.*]] = rocdl.workgroup.id.x : i32
+// ROCDL:           %[[SEXT_0:.*]] = llvm.sext %[[WORKGROUP_0]] : i32 to i64
+// ROCDL:           %[[WORKGROUP_1:.*]] = rocdl.workgroup.dim.x : i32
+// ROCDL:           %[[SEXT_1:.*]] = llvm.sext %[[WORKGROUP_1]] : i32 to i64
+// ROCDL:           %[[MUL_0:.*]] = llvm.mul %[[SEXT_0]], %[[SEXT_1]] : i64
+// ROCDL:           %[[WORKITEM_0:.*]] = rocdl.workitem.id.x : i32
+// ROCDL:           %[[SEXT_2:.*]] = llvm.sext %[[WORKITEM_0]] : i32 to i64
+// ROCDL:           %[[ADD_0:.*]] = llvm.add %[[SEXT_2]], %[[MUL_0]] : i64
+// ROCDL:           llvm.return %[[ADD_0]] : i64
+// ROCDL:         }
+
+// NVVM-LABEL:   llvm.func @gpu_global_id() -> i64 {
+// NVVM:           %[[READ_0:.*]] = nvvm.read.ptx.sreg.ctaid.x : i32
+// NVVM:           %[[SEXT_0:.*]] = llvm.sext %[[READ_0]] : i32 to i64
+// NVVM:           %[[READ_1:.*]] = nvvm.read.ptx.sreg.ntid.x : i32
+// NVVM:           %[[SEXT_1:.*]] = llvm.sext %[[READ_1]] : i32 to i64
+// NVVM:           %[[MUL_0:.*]] = llvm.mul %[[SEXT_0]], %[[SEXT_1]] : i64
+// NVVM:           %[[READ_2:.*]] = nvvm.read.ptx.sreg.tid.x : i32
+// NVVM:           %[[SEXT_2:.*]] = llvm.sext %[[READ_2]] : i32 to i64
+// NVVM:           %[[ADD_0:.*]] = llvm.add %[[SEXT_2]], %[[MUL_0]] : i64
+// NVVM:           llvm.return %[[ADD_0]] : i64
+// NVVM:         }

diff  --git a/mlir/test/Dialect/GPU/globalId-rewrite.mlir b/mlir/test/Dialect/GPU/globalId-rewrite.mlir
index 9e02d69daa436..d7d080a0093aa 100644
--- a/mlir/test/Dialect/GPU/globalId-rewrite.mlir
+++ b/mlir/test/Dialect/GPU/globalId-rewrite.mlir
@@ -8,27 +8,27 @@ module {
                threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
       // CHECK: %[[BIDY:.*]] = gpu.block_id x
       // CHECK-NEXT: %[[BDIMY:.*]] = gpu.block_dim x
-      // CHECK-NEXT: %[[TMPY:.*]] = index.mul %[[BIDY]], %[[BDIMY]]
+      // CHECK-NEXT: %[[TMPY:.*]] = arith.muli %[[BIDY]], %[[BDIMY]]
       // CHECK-NEXT: %[[TIDX:.*]] = gpu.thread_id x
-      // CHECK-NEXT: %[[GIDX:.*]] = index.add %[[TIDX]], %[[TMPY]]
+      // CHECK-NEXT: %[[GIDX:.*]] = arith.addi %[[TIDX]], %[[TMPY]]
       %idx = gpu.global_id x
       // CHECK: memref.store %[[GIDX]], %[[MEM]][] : memref<index, 1>
       memref.store %idx, %mem[] : memref<index, 1>
   
       // CHECK: %[[BIDY:.*]] = gpu.block_id y
       // CHECK-NEXT: %[[BDIMY:.*]] = gpu.block_dim y
-      // CHECK-NEXT: %[[TMPY:.*]] = index.mul %[[BIDY]], %[[BDIMY]]
+      // CHECK-NEXT: %[[TMPY:.*]] = arith.muli %[[BIDY]], %[[BDIMY]]
       // CHECK-NEXT: %[[TIDY:.*]] = gpu.thread_id y
-      // CHECK-NEXT: %[[GIDY:.*]] = index.add %[[TIDY]], %[[TMPY]]
+      // CHECK-NEXT: %[[GIDY:.*]] = arith.addi %[[TIDY]], %[[TMPY]]
       %idy = gpu.global_id y
       // CHECK: memref.store %[[GIDY]], %[[MEM]][] : memref<index, 1>
       memref.store %idy, %mem[] : memref<index, 1>
   
       // CHECK: %[[BIDZ:.*]] = gpu.block_id z
       // CHECK-NEXT: %[[BDIMZ:.*]] = gpu.block_dim z
-      // CHECK-NEXT: %[[TMPZ:.*]] = index.mul %[[BIDZ]], %[[BDIMZ]]
+      // CHECK-NEXT: %[[TMPZ:.*]] = arith.muli %[[BIDZ]], %[[BDIMZ]]
       // CHECK-NEXT: %[[TIDZ:.*]] = gpu.thread_id z
-      // CHECK-NEXT: %[[GIDZ:.*]] = index.add %[[TIDZ]], %[[TMPZ]]
+      // CHECK-NEXT: %[[GIDZ:.*]] = arith.addi %[[TIDZ]], %[[TMPZ]]
       %idz = gpu.global_id z
       // CHECK: memref.store %[[GIDZ]], %[[MEM]][] : memref<index, 1>
       memref.store %idz, %mem[] : memref<index, 1>


        


More information about the Mlir-commits mailing list