[Mlir-commits] [mlir] 0ba1361 - [MLIR][GPU] Use arith instead of index for subgroup_id (#137843)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 30 06:03:27 PDT 2025


Author: Alan Li
Date: 2025-04-30T09:03:24-04:00
New Revision: 0ba136147814cb8bc19dfa1712dad3a25b3ae27a

URL: https://github.com/llvm/llvm-project/commit/0ba136147814cb8bc19dfa1712dad3a25b3ae27a
DIFF: https://github.com/llvm/llvm-project/commit/0ba136147814cb8bc19dfa1712dad3a25b3ae27a.diff

LOG: [MLIR][GPU] Use arith instead of index for subgroup_id (#137843)

This change tries to simplify the rewriting of `gpu.subgroup_id` by emitting
`arith` dialect operations instead of `index` dialect operations.

Added: 
    

Modified: 
    mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
    mlir/test/Dialect/GPU/subgroupId-rewrite.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
index 0f0df08919553..d80578235f3c3 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp
@@ -53,6 +53,7 @@ struct GpuSubgroupIdRewriter final : OpRewritePattern<gpu::SubgroupIdOp> {
     //             subgroup_size
 
     Location loc = op->getLoc();
+    Type indexType = rewriter.getIndexType();
 
     Value dimX = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
     Value dimY = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::y);
@@ -60,16 +61,17 @@ struct GpuSubgroupIdRewriter final : OpRewritePattern<gpu::SubgroupIdOp> {
     Value tidY = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::y);
     Value tidZ = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::z);
 
-    Value dimYxIdZ = rewriter.create<index::MulOp>(loc, dimY, tidZ);
-    Value dimYxIdZPlusIdY = rewriter.create<index::AddOp>(loc, dimYxIdZ, tidY);
+    Value dimYxIdZ = rewriter.create<arith::MulIOp>(loc, indexType, dimY, tidZ);
+    Value dimYxIdZPlusIdY =
+        rewriter.create<arith::AddIOp>(loc, indexType, dimYxIdZ, tidY);
     Value dimYxIdZPlusIdYTimesDimX =
-        rewriter.create<index::MulOp>(loc, dimX, dimYxIdZPlusIdY);
-    Value IdXPlusDimYxIdZPlusIdYTimesDimX =
-        rewriter.create<index::AddOp>(loc, tidX, dimYxIdZPlusIdYTimesDimX);
+        rewriter.create<arith::MulIOp>(loc, indexType, dimX, dimYxIdZPlusIdY);
+    Value IdXPlusDimYxIdZPlusIdYTimesDimX = rewriter.create<arith::AddIOp>(
+        loc, indexType, tidX, dimYxIdZPlusIdYTimesDimX);
     Value subgroupSize = rewriter.create<gpu::SubgroupSizeOp>(
         loc, rewriter.getIndexType(), /*upper_bound = */ nullptr);
-    Value subgroupIdOp = rewriter.create<index::DivUOp>(
-        loc, IdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
+    Value subgroupIdOp = rewriter.create<arith::DivUIOp>(
+        loc, indexType, IdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
     rewriter.replaceOp(op, {subgroupIdOp});
     return success();
   }

diff  --git a/mlir/test/Dialect/GPU/subgroupId-rewrite.mlir b/mlir/test/Dialect/GPU/subgroupId-rewrite.mlir
index a0c852f6fbe88..386793ad88649 100644
--- a/mlir/test/Dialect/GPU/subgroupId-rewrite.mlir
+++ b/mlir/test/Dialect/GPU/subgroupId-rewrite.mlir
@@ -10,12 +10,12 @@ func.func @subgroupId(%sz : index, %mem: memref<index, 1>) {
     // CHECK-NEXT: %[[TIDX:.*]] = gpu.thread_id  x
     // CHECK-NEXT: %[[TIDY:.*]] = gpu.thread_id  y
     // CHECK-NEXT: %[[TIDZ:.*]] = gpu.thread_id  z
-    // CHECK-NEXT: %[[T0:.*]] = index.mul %[[DIMY]], %[[TIDZ]]
-    // CHECK-NEXT: %[[T1:.*]] = index.add %[[T0]], %[[TIDY]]
-    // CHECK-NEXT: %[[T2:.*]] = index.mul %[[DIMX]], %[[T1]]
-    // CHECK-NEXT: %[[T3:.*]] = index.add %[[TIDX]], %[[T2]]
+    // CHECK-NEXT: %[[T0:.*]] = arith.muli %[[DIMY]], %[[TIDZ]] : index
+    // CHECK-NEXT: %[[T1:.*]] = arith.addi %[[T0]], %[[TIDY]] : index
+    // CHECK-NEXT: %[[T2:.*]] = arith.muli %[[DIMX]], %[[T1]] : index
+    // CHECK-NEXT: %[[T3:.*]] = arith.addi %[[TIDX]], %[[T2]] : index
     // CHECK-NEXT: %[[T4:.*]] = gpu.subgroup_size : index
-    // CHECK-NEXT: %[[T5:.*]] = index.divu %[[T3]], %[[T4]]
+    // CHECK-NEXT: %[[T5:.*]] = arith.divui %[[T3]], %[[T4]] : index
     %idz = gpu.subgroup_id : index
     memref.store %idz, %mem[] : memref<index, 1>
     gpu.terminator


        


More information about the Mlir-commits mailing list