[Mlir-commits] [mlir] [ROCDL] Added wave.id to rocdl; add `rsq` to rocdl.math (PR #176028)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 14 12:25:02 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Ravil Dorozhinskii (ravil-mobile)
<details>
<summary>Changes</summary>
This PR adds `wave.id` to rocdl, adds `rsq` to rocdl.math, and fixes the global/flat prefetch ops by adding the `MemWrite` effect alongside `MemRead` on their pointer argument.
---
Full diff: https://github.com/llvm/llvm-project/pull/176028.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+4-2)
- (modified) mlir/test/Dialect/LLVMIR/rocdl.mlir (+9)
- (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+12-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 265c2e99f52d6..63b3c62427b9f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -351,6 +351,7 @@ def ROCDL_ClusterIdXOp : ROCDL_SpecialIdRegisterOp<"cluster.id.x">;
def ROCDL_ClusterIdYOp : ROCDL_SpecialIdRegisterOp<"cluster.id.y">;
def ROCDL_ClusterIdZOp : ROCDL_SpecialIdRegisterOp<"cluster.id.z">;
+def ROCDL_WaveId : ROCDL_SpecialIdRegisterOp<"wave.id">;
def ROCDL_WavefrontSizeOp : ROCDL_SpecialIdRegisterOp<"wavefrontsize">;
//===----------------------------------------------------------------------===//
@@ -1298,7 +1299,7 @@ def ROCDL_RawBufferAtomicCmpSwap :
def ROCDL_GlobalPrefetchOp :
ROCDL_IntrOp<"global.prefetch", [], [], [], 0, 0, 1, 0, [1], ["scope"]> {
- dag args = (ins Arg<LLVM_PointerInAddressSpace<1>, "", [MemRead]>:$ptr,
+ dag args = (ins Arg<LLVM_PointerInAddressSpace<1>, "", [MemRead, MemWrite]>:$ptr,
I32Attr:$scope);
let arguments = !con(args, baseArgs);
let description = [{
@@ -1316,7 +1317,7 @@ def ROCDL_GlobalPrefetchOp :
def ROCDL_FlatPrefetchOp :
ROCDL_IntrOp<"flat.prefetch", [], [], [], 0, 0, 1, 0, [1], ["scope"]> {
- dag args = (ins Arg<LLVM_PointerInAddressSpace<0>, "", [MemRead]>:$ptr,
+ dag args = (ins Arg<LLVM_PointerInAddressSpace<0>, "", [MemRead, MemWrite]>:$ptr,
I32Attr:$scope);
let arguments = !con(args, baseArgs);
let description = [{
@@ -2140,6 +2141,7 @@ def ROCDLExp : ROCDL_Math_IntrOp<"exp">;
def ROCDLExp2 : ROCDL_Math_IntrOp<"exp2">;
def ROCDLLog : ROCDL_Math_IntrOp<"log">;
def ROCDLSqrt : ROCDL_Math_IntrOp<"sqrt">;
+def ROCDLRsq : ROCDL_Math_IntrOp<"rsq">;
//===----------------------------------------------------------------------===//
// ROCDL target attribute.
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index cf2b144219f36..9e7fb70a4a6cf 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -32,6 +32,8 @@ func.func @rocdl_special_regs() -> i32 {
%13 = rocdl.grid.dim.y : i32
// CHECK: rocdl.grid.dim.z : i32
%14 = rocdl.grid.dim.z : i32
+ // CHECK: rocdl.wave.id : i32
+ %15 = rocdl.wave.id : i32
llvm.return %0 : i32
}
@@ -99,6 +101,13 @@ func.func @rocdl.math.ops(%a: f32, %b: f16, %c: bf16) {
%sqrt0 = rocdl.sqrt %a f32 -> f32
%sqrt1 = rocdl.sqrt %b f16 -> f16
%sqrt2 = rocdl.sqrt %c bf16 -> bf16
+
+ // CHECK: %{{.*}} = rocdl.rsq %{{.*}} f32 -> f32
+ // CHECK: %{{.*}} = rocdl.rsq %{{.*}} f16 -> f16
+ // CHECK: %{{.*}} = rocdl.rsq %{{.*}} bf16 -> bf16
+ %rsq0 = rocdl.rsq %a f32 -> f32
+ %rsq1 = rocdl.rsq %b f16 -> f16
+ %rsq2 = rocdl.rsq %c bf16 -> bf16
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index dc6a00e19afc3..4b2f0dd68799a 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -39,11 +39,14 @@ llvm.func @rocdl_special_regs() -> i32 {
// CHECK: call range(i64 1, 65) i64 @__ockl_get_local_size(i32 0)
%17 = rocdl.workgroup.dim.x range <i32, 1, 65> : i64
+ // CHECK: call i32 @llvm.amdgcn.wave.id()
+ %18 = rocdl.wave.id : i32
+
// CHECK: call i32 @llvm.amdgcn.wavefrontsize()
- %18 = rocdl.wavefrontsize : i32
+ %19 = rocdl.wavefrontsize : i32
// CHECK: call range(i32 32, 65) i32 @llvm.amdgcn.wavefrontsize()
- %19 = rocdl.wavefrontsize range <i32, 32, 65> : i32
+ %20 = rocdl.wavefrontsize range <i32, 32, 65> : i32
llvm.return %1 : i32
}
@@ -111,6 +114,13 @@ llvm.func @kernel_math_ops(%a: f32, %b: f16, %c: bf16) {
%sqrt0 = rocdl.sqrt %a f32 -> f32
%sqrt1 = rocdl.sqrt %b f16 -> f16
%sqrt2 = rocdl.sqrt %c bf16 -> bf16
+
+ // CHECK: call float @llvm.amdgcn.rsq.f32(float %{{.*}})
+ // CHECK: call half @llvm.amdgcn.rsq.f16(half %{{.*}})
+ // CHECK: call bfloat @llvm.amdgcn.rsq.bf16(bfloat %{{.*}})
+ %rsq0 = rocdl.rsq %a f32 -> f32
+ %rsq1 = rocdl.rsq %b f16 -> f16
+ %rsq2 = rocdl.rsq %c bf16 -> bf16
llvm.return
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/176028
More information about the Mlir-commits mailing list