[Mlir-commits] [mlir] 7305ed7 - [ROCDL] added math instructions to the ROCDL dialect (#169672)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 8 10:34:17 PST 2025


Author: Ravil Dorozhinskii
Date: 2025-12-08T19:34:13+01:00
New Revision: 7305ed7e1554d5af652bdf7afc731f217cc49f09

URL: https://github.com/llvm/llvm-project/commit/7305ed7e1554d5af652bdf7afc731f217cc49f09
DIFF: https://github.com/llvm/llvm-project/commit/7305ed7e1554d5af652bdf7afc731f217cc49f09.diff

LOG: [ROCDL] added math instructions to the ROCDL dialect (#169672)

Expose the LLVM AMDGCN math intrinsics (`llvm.amdgcn.*`) as operations in the ROCDL dialect.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
    mlir/test/Dialect/LLVMIR/rocdl.mlir
    mlir/test/Target/LLVMIR/rocdl.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 0edb208a8fcba..cd36300d7ac16 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -1913,6 +1913,40 @@ def ROCDL_FMed3Op : ROCDL_IntrOp<"fmed3", [0], [], [Pure, AllTypesMatch<["res",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Math operations
+//===----------------------------------------------------------------------===//
+
+// Unary floating-point op lowered to `llvm.amdgcn.<mnemonic>`, overloaded on
+// the result type (e.g. `rocdl.tanh` on f32 becomes
+// `call float @llvm.amdgcn.tanh.f32(float ...)`). These intrinsics take and
+// return the same type, so operand/result type equality is always enforced,
+// rejecting forms such as `rocdl.tanh %x f32 -> f16`.
+class ROCDL_Math_IntrOp<string mnemonic, list<Trait> traits = [Pure]> :
+  ROCDL_IntrOp<mnemonic, [0], [],
+               !listconcat(traits, [AllTypesMatch<["arg", "res"]>]), 1>,
+  Arguments<(ins LLVM_AnyFloat:$arg)> {
+  let results = (outs LLVM_AnyFloat:$res);
+  let description = [{
+    Note: In the general case, prefer the conventional `arith`, `math`, or `llvm` ops over this.
+    Use this ROCDL-specific operation only when you fully understand its implications and
+    when it is strictly necessary. This op is usually chosen when a small loss in precision is
+    acceptable in exchange for higher execution speed.
+  }];
+  let assemblyFormat =
+    "$arg qualified(type($arg)) attr-dict `->` qualified(type($res))";
+}
+
+def ROCDLTanh : ROCDL_Math_IntrOp<"tanh">;
+def ROCDLSin : ROCDL_Math_IntrOp<"sin">;
+def ROCDLCos : ROCDL_Math_IntrOp<"cos">;
+def ROCDLRcp : ROCDL_Math_IntrOp<"rcp">;
+// No `exp` op: `llvm.amdgcn.exp` is the AMDGPU export intrinsic (no result,
+// target/enable/done operands), not e^x. Use `rocdl.exp2` on x * log2(e).
+def ROCDLExp2 : ROCDL_Math_IntrOp<"exp2">;
+def ROCDLLog : ROCDL_Math_IntrOp<"log">;
+def ROCDLSqrt : ROCDL_Math_IntrOp<"sqrt">;
+
 //===----------------------------------------------------------------------===//
 // ROCDL target attribute.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 1b50feea418b6..40084bc07d4f7 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -49,6 +49,60 @@ func.func @rocdl.fmed3.vector(%a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4
   llvm.return %0 : vector<4xf16>
 }
 
+// Exercises the custom assembly format of the ROCDL math ops on f32, f16 and bf16.
+func.func @rocdl.math.ops(%x32: f32, %x16: f16, %xbf: bf16) {
+  // CHECK-LABEL: rocdl.math.ops
+  // CHECK: %{{.*}} = rocdl.tanh %{{.*}} f32 -> f32
+  %tanh_f32 = rocdl.tanh %x32 f32 -> f32
+  // CHECK: %{{.*}} = rocdl.tanh %{{.*}} f16 -> f16
+  %tanh_f16 = rocdl.tanh %x16 f16 -> f16
+  // CHECK: %{{.*}} = rocdl.tanh %{{.*}} bf16 -> bf16
+  %tanh_bf16 = rocdl.tanh %xbf bf16 -> bf16
+
+  // CHECK: %{{.*}} = rocdl.sin %{{.*}} f32 -> f32
+  %sin_f32 = rocdl.sin %x32 f32 -> f32
+  // CHECK: %{{.*}} = rocdl.sin %{{.*}} f16 -> f16
+  %sin_f16 = rocdl.sin %x16 f16 -> f16
+  // CHECK: %{{.*}} = rocdl.sin %{{.*}} bf16 -> bf16
+  %sin_bf16 = rocdl.sin %xbf bf16 -> bf16
+
+  // CHECK: %{{.*}} = rocdl.cos %{{.*}} f32 -> f32
+  %cos_f32 = rocdl.cos %x32 f32 -> f32
+  // CHECK: %{{.*}} = rocdl.cos %{{.*}} f16 -> f16
+  %cos_f16 = rocdl.cos %x16 f16 -> f16
+  // CHECK: %{{.*}} = rocdl.cos %{{.*}} bf16 -> bf16
+  %cos_bf16 = rocdl.cos %xbf bf16 -> bf16
+
+  // CHECK: %{{.*}} = rocdl.rcp %{{.*}} f32 -> f32
+  %rcp_f32 = rocdl.rcp %x32 f32 -> f32
+  // CHECK: %{{.*}} = rocdl.rcp %{{.*}} f16 -> f16
+  %rcp_f16 = rocdl.rcp %x16 f16 -> f16
+  // CHECK: %{{.*}} = rocdl.rcp %{{.*}} bf16 -> bf16
+  %rcp_bf16 = rocdl.rcp %xbf bf16 -> bf16
+
+  // CHECK: %{{.*}} = rocdl.exp2 %{{.*}} f32 -> f32
+  %exp2_f32 = rocdl.exp2 %x32 f32 -> f32
+  // CHECK: %{{.*}} = rocdl.exp2 %{{.*}} f16 -> f16
+  %exp2_f16 = rocdl.exp2 %x16 f16 -> f16
+  // CHECK: %{{.*}} = rocdl.exp2 %{{.*}} bf16 -> bf16
+  %exp2_bf16 = rocdl.exp2 %xbf bf16 -> bf16
+
+  // CHECK: %{{.*}} = rocdl.log %{{.*}} f32 -> f32
+  %log_f32 = rocdl.log %x32 f32 -> f32
+  // CHECK: %{{.*}} = rocdl.log %{{.*}} f16 -> f16
+  %log_f16 = rocdl.log %x16 f16 -> f16
+  // CHECK: %{{.*}} = rocdl.log %{{.*}} bf16 -> bf16
+  %log_bf16 = rocdl.log %xbf bf16 -> bf16
+
+  // CHECK: %{{.*}} = rocdl.sqrt %{{.*}} f32 -> f32
+  %sqrt_f32 = rocdl.sqrt %x32 f32 -> f32
+  // CHECK: %{{.*}} = rocdl.sqrt %{{.*}} f16 -> f16
+  %sqrt_f16 = rocdl.sqrt %x16 f16 -> f16
+  // CHECK: %{{.*}} = rocdl.sqrt %{{.*}} bf16 -> bf16
+  %sqrt_bf16 = rocdl.sqrt %xbf bf16 -> bf16
+  llvm.return
+}
+
 func.func @rocdl.barrier() {
   // CHECK: rocdl.barrier
   rocdl.barrier

diff  --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 7be6d6ba4d7be..2c748ad509356 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -61,6 +61,60 @@ llvm.func @kernel_func_workgroups()
   llvm.return
 }
 
+// Each ROCDL math op is expected to become the matching llvm.amdgcn.* call.
+llvm.func @kernel_math_ops(%x32: f32, %x16: f16, %xbf: bf16) {
+  // CHECK-LABEL: kernel_math_ops
+  // CHECK: call float @llvm.amdgcn.tanh.f32(float %{{.*}})
+  %tanh_f32 = rocdl.tanh %x32 f32 -> f32
+  // CHECK: call half @llvm.amdgcn.tanh.f16(half %{{.*}})
+  %tanh_f16 = rocdl.tanh %x16 f16 -> f16
+  // CHECK: call bfloat @llvm.amdgcn.tanh.bf16(bfloat %{{.*}})
+  %tanh_bf16 = rocdl.tanh %xbf bf16 -> bf16
+
+  // CHECK: call float @llvm.amdgcn.sin.f32(float %{{.*}})
+  %sin_f32 = rocdl.sin %x32 f32 -> f32
+  // CHECK: call half @llvm.amdgcn.sin.f16(half %{{.*}})
+  %sin_f16 = rocdl.sin %x16 f16 -> f16
+  // CHECK: call bfloat @llvm.amdgcn.sin.bf16(bfloat %{{.*}})
+  %sin_bf16 = rocdl.sin %xbf bf16 -> bf16
+
+  // CHECK: call float @llvm.amdgcn.cos.f32(float %{{.*}})
+  %cos_f32 = rocdl.cos %x32 f32 -> f32
+  // CHECK: call half @llvm.amdgcn.cos.f16(half %{{.*}})
+  %cos_f16 = rocdl.cos %x16 f16 -> f16
+  // CHECK: call bfloat @llvm.amdgcn.cos.bf16(bfloat %{{.*}})
+  %cos_bf16 = rocdl.cos %xbf bf16 -> bf16
+
+  // CHECK: call float @llvm.amdgcn.rcp.f32(float %{{.*}})
+  %rcp_f32 = rocdl.rcp %x32 f32 -> f32
+  // CHECK: call half @llvm.amdgcn.rcp.f16(half %{{.*}})
+  %rcp_f16 = rocdl.rcp %x16 f16 -> f16
+  // CHECK: call bfloat @llvm.amdgcn.rcp.bf16(bfloat %{{.*}})
+  %rcp_bf16 = rocdl.rcp %xbf bf16 -> bf16
+
+  // CHECK: call float @llvm.amdgcn.exp2.f32(float %{{.*}})
+  %exp2_f32 = rocdl.exp2 %x32 f32 -> f32
+  // CHECK: call half @llvm.amdgcn.exp2.f16(half %{{.*}})
+  %exp2_f16 = rocdl.exp2 %x16 f16 -> f16
+  // CHECK: call bfloat @llvm.amdgcn.exp2.bf16(bfloat %{{.*}})
+  %exp2_bf16 = rocdl.exp2 %xbf bf16 -> bf16
+
+  // CHECK: call float @llvm.amdgcn.log.f32(float %{{.*}})
+  %log_f32 = rocdl.log %x32 f32 -> f32
+  // CHECK: call half @llvm.amdgcn.log.f16(half %{{.*}})
+  %log_f16 = rocdl.log %x16 f16 -> f16
+  // CHECK: call bfloat @llvm.amdgcn.log.bf16(bfloat %{{.*}})
+  %log_bf16 = rocdl.log %xbf bf16 -> bf16
+
+  // CHECK: call float @llvm.amdgcn.sqrt.f32(float %{{.*}})
+  %sqrt_f32 = rocdl.sqrt %x32 f32 -> f32
+  // CHECK: call half @llvm.amdgcn.sqrt.f16(half %{{.*}})
+  %sqrt_f16 = rocdl.sqrt %x16 f16 -> f16
+  // CHECK: call bfloat @llvm.amdgcn.sqrt.bf16(bfloat %{{.*}})
+  %sqrt_bf16 = rocdl.sqrt %xbf bf16 -> bf16
+  llvm.return
+}
+
 llvm.func @known_block_sizes()
     attributes {rocdl.kernel,
       rocdl.flat_work_group_size = "128,128",


        


More information about the Mlir-commits mailing list