[Mlir-commits] [mlir] [MLIR][ROCDL] Add conversion for gpu.subgroup_id to ROCDL (PR #136405)

Jakub Kuderski llvmlistbot at llvm.org
Wed Apr 23 08:07:04 PDT 2025


================
@@ -190,6 +201,49 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   }
 };
 
+struct GPUSubgroupIdOpToROCDL final
+    : ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  GPUSubgroupIdOpToROCDL(const LLVMTypeConverter &converter,
+                         const mlir::amdgpu::Chipset &chipset)
+      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}
+
+  const mlir::amdgpu::Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
+    auto loc = op.getLoc();
+    LLVM::IntegerOverflowFlags flags =
+        LLVM::IntegerOverflowFlags::nsw | LLVM::IntegerOverflowFlags::nuw;
+    // (w_id.x + w_dim.x * (w_id.y + w_dim.y * w_id.z)) / subgroup_size
+    Value workitemIdX = rewriter.create<ROCDL::ThreadIdXOp>(loc, int32Type);
+    Value workitemIdY = rewriter.create<ROCDL::ThreadIdYOp>(loc, int32Type);
+    Value workitemIdZ = rewriter.create<ROCDL::ThreadIdZOp>(loc, int32Type);
+    Value workitemDimX = rewriter.create<ROCDL::BlockDimXOp>(loc, int32Type);
+    Value workitemDimY = rewriter.create<ROCDL::BlockDimYOp>(loc, int32Type);
+    Value dimYxIdZ = rewriter.create<LLVM::MulOp>(loc, int32Type, workitemDimY,
+                                                  workitemIdZ, flags);
+    Value dimYxIdZPlusIdY = rewriter.create<LLVM::AddOp>(
+        loc, int32Type, dimYxIdZ, workitemIdY, flags);
+    Value dimYxIdZPlusIdYTimesDimX = rewriter.create<LLVM::MulOp>(
+        loc, int32Type, workitemDimX, dimYxIdZPlusIdY, flags);
+    Value workitemIdXPlusDimYxIdZPlusIdYTimesDimX =
+        rewriter.create<LLVM::AddOp>(loc, int32Type, workitemIdX,
+                                     dimYxIdZPlusIdYTimesDimX, flags);
+    Value subgroupSize = rewriter.create<LLVM::ConstantOp>(
+        loc, IntegerType::get(rewriter.getContext(), 32), 64);
----------------
kuhar wrote:

Why is it safe to hardcode this for gfx9 here?

https://github.com/llvm/llvm-project/pull/136405


More information about the Mlir-commits mailing list