[Mlir-commits] [mlir] [ROCDL] Added wave.id to rocdl; add `rsq` to rocdl.math (PR #176028)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 14 12:25:02 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Ravil Dorozhinskii (ravil-mobile)

<details>
<summary>Changes</summary>

This PR adds `wave.id` to rocdl, adds `rsq` to rocdl.math, and fixes global/flat prefetch by adding the `MemWrite` trait to its pointer argument.

---
Full diff: https://github.com/llvm/llvm-project/pull/176028.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+4-2) 
- (modified) mlir/test/Dialect/LLVMIR/rocdl.mlir (+9) 
- (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+12-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 265c2e99f52d6..63b3c62427b9f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -351,6 +351,7 @@ def ROCDL_ClusterIdXOp : ROCDL_SpecialIdRegisterOp<"cluster.id.x">;
 def ROCDL_ClusterIdYOp : ROCDL_SpecialIdRegisterOp<"cluster.id.y">;
 def ROCDL_ClusterIdZOp : ROCDL_SpecialIdRegisterOp<"cluster.id.z">;
 
+def ROCDL_WaveId : ROCDL_SpecialIdRegisterOp<"wave.id">;
 def ROCDL_WavefrontSizeOp : ROCDL_SpecialIdRegisterOp<"wavefrontsize">;
 
 //===----------------------------------------------------------------------===//
@@ -1298,7 +1299,7 @@ def ROCDL_RawBufferAtomicCmpSwap :
 
 def ROCDL_GlobalPrefetchOp :
   ROCDL_IntrOp<"global.prefetch", [], [], [], 0, 0, 1, 0, [1], ["scope"]> {
-  dag args = (ins Arg<LLVM_PointerInAddressSpace<1>, "", [MemRead]>:$ptr,
+  dag args = (ins Arg<LLVM_PointerInAddressSpace<1>, "", [MemRead, MemWrite]>:$ptr,
                   I32Attr:$scope);
   let arguments = !con(args, baseArgs);
   let description = [{
@@ -1316,7 +1317,7 @@ def ROCDL_GlobalPrefetchOp :
 
 def ROCDL_FlatPrefetchOp :
   ROCDL_IntrOp<"flat.prefetch", [], [], [], 0, 0, 1, 0, [1], ["scope"]> {
-  dag args = (ins Arg<LLVM_PointerInAddressSpace<0>, "", [MemRead]>:$ptr,
+  dag args = (ins Arg<LLVM_PointerInAddressSpace<0>, "", [MemRead, MemWrite]>:$ptr,
                   I32Attr:$scope);
   let arguments = !con(args, baseArgs);
   let description = [{
@@ -2140,6 +2141,7 @@ def ROCDLExp : ROCDL_Math_IntrOp<"exp">;
 def ROCDLExp2 : ROCDL_Math_IntrOp<"exp2">;
 def ROCDLLog : ROCDL_Math_IntrOp<"log">;
 def ROCDLSqrt : ROCDL_Math_IntrOp<"sqrt">;
+def ROCDLRsq : ROCDL_Math_IntrOp<"rsq">;
 
 //===----------------------------------------------------------------------===//
 // ROCDL target attribute.
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index cf2b144219f36..9e7fb70a4a6cf 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -32,6 +32,8 @@ func.func @rocdl_special_regs() -> i32 {
   %13 = rocdl.grid.dim.y : i32
   // CHECK: rocdl.grid.dim.z : i32
   %14 = rocdl.grid.dim.z : i32
+  // CHECK: rocdl.wave.id : i32
+  %15 = rocdl.wave.id : i32
   llvm.return %0 : i32
 }
 
@@ -99,6 +101,13 @@ func.func @rocdl.math.ops(%a: f32, %b: f16, %c: bf16) {
   %sqrt0 = rocdl.sqrt %a f32 -> f32
   %sqrt1 = rocdl.sqrt %b f16 -> f16
   %sqrt2 = rocdl.sqrt %c bf16 -> bf16
+
+  // CHECK: %{{.*}} = rocdl.rsq %{{.*}} f32 -> f32
+  // CHECK: %{{.*}} = rocdl.rsq %{{.*}} f16 -> f16
+  // CHECK: %{{.*}} = rocdl.rsq %{{.*}} bf16 -> bf16
+  %rsq0 = rocdl.rsq %a f32 -> f32
+  %rsq1 = rocdl.rsq %b f16 -> f16
+  %rsq2 = rocdl.rsq %c bf16 -> bf16
   llvm.return
 }
 
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index dc6a00e19afc3..4b2f0dd68799a 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -39,11 +39,14 @@ llvm.func @rocdl_special_regs() -> i32 {
   // CHECK: call range(i64 1, 65) i64 @__ockl_get_local_size(i32 0)
   %17 = rocdl.workgroup.dim.x range <i32, 1, 65> : i64
 
+  // CHECK: call i32 @llvm.amdgcn.wave.id()
+  %18 = rocdl.wave.id : i32
+
   // CHECK: call i32 @llvm.amdgcn.wavefrontsize()
-  %18 = rocdl.wavefrontsize : i32
+  %19 = rocdl.wavefrontsize : i32
 
   // CHECK: call range(i32 32, 65) i32 @llvm.amdgcn.wavefrontsize()
-  %19 = rocdl.wavefrontsize range <i32, 32, 65> : i32
+  %20 = rocdl.wavefrontsize range <i32, 32, 65> : i32
 
   llvm.return %1 : i32
 }
@@ -111,6 +114,13 @@ llvm.func @kernel_math_ops(%a: f32, %b: f16, %c: bf16) {
   %sqrt0 = rocdl.sqrt %a f32 -> f32
   %sqrt1 = rocdl.sqrt %b f16 -> f16
   %sqrt2 = rocdl.sqrt %c bf16 -> bf16
+
+  // CHECK: call float @llvm.amdgcn.rsq.f32(float %{{.*}})
+  // CHECK: call half @llvm.amdgcn.rsq.f16(half %{{.*}})
+  // CHECK: call bfloat @llvm.amdgcn.rsq.bf16(bfloat %{{.*}})
+  %rsq0 = rocdl.rsq %a f32 -> f32
+  %rsq1 = rocdl.rsq %b f16 -> f16
+  %rsq2 = rocdl.rsq %c bf16 -> bf16
   llvm.return
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/176028


More information about the Mlir-commits mailing list