[Mlir-commits] [mlir] [ROCDL][LLVM] Added rocdl.fmed3 -> Intrinsic::amdgcn_fmed3 (PR #159332)

Keshav Vinayak Jha llvmlistbot at llvm.org
Sat Sep 27 06:21:46 PDT 2025


https://github.com/keshavvinayak01 updated https://github.com/llvm/llvm-project/pull/159332

>From 285f5027e39bac2570c914fb5fb96d2658f559f0 Mon Sep 17 00:00:00 2001
From: keshavvinayak01 <keshavvinayakjha at gmail.com>
Date: Wed, 17 Sep 2025 11:40:49 +0000
Subject: [PATCH 1/2] Added fmed3 rocdl op

Signed-off-by: keshavvinayak01 <keshavvinayakjha at gmail.com>
---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 20 ++++++++++++++++++++
 mlir/test/Target/LLVMIR/rocdl.mlir           | 14 ++++++++++++++
 2 files changed, 34 insertions(+)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 9fa3ec1fc4b21..1d31ec069b5c0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -1291,6 +1291,26 @@ def ROCDL_CvtScaleF32PkFp4F32Op :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// MED3 operations
+//===----------------------------------------------------------------------===//
+
+def ROCDL_Med3Op : ROCDL_ConcreteNonMemIntrOp<"med3", [Pure, AllTypesMatch<["res", "src0", "src1", "src2"]>], 1>,
+  Arguments<(ins LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$src0,
+                 LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$src1,
+                 LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$src2)> {
+  let results = (outs LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$res);
+  let summary = "Median of three float/half values";
+  let assemblyFormat = [{
+    $src0 `,` $src1 `,` $src2 attr-dict `:` `(` type($src0) `,` type($src1) `,` type($src2) `)` `->` type($res)
+  }];
+  string llvmBuilder = [{
+    $res = createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_fmed3,
+                               {$src0, $src1, $src2},
+                               {moduleTranslation.convertType(op.getRes().getType())});
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ROCDL target attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index a464358250c38..579669f646ceb 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1298,6 +1298,20 @@ llvm.func @rocdl_last_use(%ptr: !llvm.ptr<1>) -> i32 {
   llvm.return %ret : i32
 }
 
+llvm.func @test_med3_f16(%arg0: f16, %arg1: f16, %arg2: f16) -> f16 {
+  // CHECK-LABEL: define half @test_med3_f16(half %0, half %1, half %2)
+  %0 = rocdl.med3 %arg0, %arg1, %arg2 : (f16, f16, f16) -> f16
+  llvm.return %0 : f16
+  // CHECK: call half @llvm.amdgcn.fmed3.f16(half %0, half %1, half %2)
+}
+
+llvm.func @test_med3_f32(%arg0: f32, %arg1: f32, %arg2: f32) -> f32 {
+  // CHECK-LABEL: define float @test_med3_f32(float %0, float %1, float %2)
+  %0 = rocdl.med3 %arg0, %arg1, %arg2 : (f32, f32, f32) -> f32
+  llvm.return %0 : f32
+  // CHECK: call float @llvm.amdgcn.fmed3.f32(float %0, float %1, float %2)
+}
+
 // CHECK-DAG: attributes #[[$KERNEL_ATTRS]] = { "amdgpu-flat-work-group-size"="1,256" "uniform-work-group-size"="true" }
 // CHECK-DAG: attributes #[[$KERNEL_WORKGROUP_ATTRS]] = { "amdgpu-flat-work-group-size"="1,1024"
 // CHECK-DAG: attributes #[[$KNOWN_BLOCK_SIZE_ATTRS]] = { "amdgpu-flat-work-group-size"="128,128"

>From 829dc0ff8d7db139b7497531430612868291dd2a Mon Sep 17 00:00:00 2001
From: keshavvinayak01 <keshavvinayakjha at gmail.com>
Date: Sun, 21 Sep 2025 10:29:30 +0000
Subject: [PATCH 2/2] Variadic rocdl.fmed3 op; print tests; addressed comments

Signed-off-by: keshavvinayak01 <keshavvinayakjha at gmail.com>
---
 mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 27 ++++++++++++++------
 mlir/test/Dialect/LLVMIR/rocdl.mlir          | 14 ++++++++++
 mlir/test/Target/LLVMIR/rocdl.mlir           | 12 ++++-----
 3 files changed, 39 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 1d31ec069b5c0..8def3b25f5c28 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -1292,22 +1292,33 @@ def ROCDL_CvtScaleF32PkFp4F32Op :
 }
 
 //===----------------------------------------------------------------------===//
-// MED3 operations
+// FMED3 operations
 //===----------------------------------------------------------------------===//
 
-def ROCDL_Med3Op : ROCDL_ConcreteNonMemIntrOp<"med3", [Pure, AllTypesMatch<["res", "src0", "src1", "src2"]>], 1>,
+def ROCDL_FMed3Op : ROCDL_IntrOp<"fmed3", [0], [], [Pure, AllTypesMatch<["res", "src0", "src1", "src2"]>], 1>,
   Arguments<(ins LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$src0,
                  LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$src1,
                  LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$src2)> {
   let results = (outs LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$res);
   let summary = "Median of three float/half values";
-  let assemblyFormat = [{
-    $src0 `,` $src1 `,` $src2 attr-dict `:` `(` type($src0) `,` type($src1) `,` type($src2) `)` `->` type($res)
+  let description = [{
+    Computes the median of three floating-point values using the AMDGPU fmed3 intrinsic.
+    This operation is equivalent to `max(min(a, b), min(max(a, b), c))` but uses the
+    V_MED3_F16/V_MED3_F32 hardware instructions for better performance.
+
+    All operands and the result must have the same scalar or vector float type (e.g. f16, f32).
+
+    Example:
+    ```mlir
+    // Scalar f32 median
+    %result = rocdl.fmed3 %a, %b, %c : f32
+
+    // Vector f16 median
+    %result = rocdl.fmed3 %va, %vb, %vc : vector<4xf16>
+    ```
   }];
-  string llvmBuilder = [{
-    $res = createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_fmed3,
-                               {$src0, $src1, $src2},
-                               {moduleTranslation.convertType(op.getRes().getType())});
+  let assemblyFormat = [{
+    $src0 `,` $src1 `,` $src2 attr-dict `:` type($res)
   }];
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 782ef4e154440..55df41128f16a 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -29,6 +29,20 @@ func.func @rocdl_special_regs() -> i32 {
   llvm.return %0 : i32
 }
 
+func.func @rocdl.fmed3.scalar(%a: f32, %b: f32, %c: f32) -> f32 {
+  // CHECK-LABEL: rocdl.fmed3.scalar
+  // CHECK: rocdl.fmed3 %{{.*}}, %{{.*}}, %{{.*}} : f32
+  %0 = rocdl.fmed3 %a, %b, %c : f32
+  llvm.return %0 : f32
+}
+
+func.func @rocdl.fmed3.vector(%a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf16>) -> vector<4xf16> {
+  // CHECK-LABEL: rocdl.fmed3.vector
+  // CHECK: rocdl.fmed3 %{{.*}}, %{{.*}}, %{{.*}} : vector<4xf16>
+  %0 = rocdl.fmed3 %a, %b, %c : vector<4xf16>
+  llvm.return %0 : vector<4xf16>
+}
+
 func.func @rocdl.barrier() {
   // CHECK: rocdl.barrier
   rocdl.barrier
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 579669f646ceb..d55c53362f6e0 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1298,16 +1298,16 @@ llvm.func @rocdl_last_use(%ptr: !llvm.ptr<1>) -> i32 {
   llvm.return %ret : i32
 }
 
-llvm.func @test_med3_f16(%arg0: f16, %arg1: f16, %arg2: f16) -> f16 {
-  // CHECK-LABEL: define half @test_med3_f16(half %0, half %1, half %2)
-  %0 = rocdl.med3 %arg0, %arg1, %arg2 : (f16, f16, f16) -> f16
+llvm.func @test_fmed3_f16(%arg0: f16, %arg1: f16, %arg2: f16) -> f16 {
+  // CHECK-LABEL: define half @test_fmed3_f16(half %0, half %1, half %2)
+  %0 = rocdl.fmed3 %arg0, %arg1, %arg2 : f16
   llvm.return %0 : f16
   // CHECK: call half @llvm.amdgcn.fmed3.f16(half %0, half %1, half %2)
 }
 
-llvm.func @test_med3_f32(%arg0: f32, %arg1: f32, %arg2: f32) -> f32 {
-  // CHECK-LABEL: define float @test_med3_f32(float %0, float %1, float %2)
-  %0 = rocdl.med3 %arg0, %arg1, %arg2 : (f32, f32, f32) -> f32
+llvm.func @test_fmed3_f32(%arg0: f32, %arg1: f32, %arg2: f32) -> f32 {
+  // CHECK-LABEL: define float @test_fmed3_f32(float %0, float %1, float %2)
+  %0 = rocdl.fmed3 %arg0, %arg1, %arg2 : f32
   llvm.return %0 : f32
   // CHECK: call float @llvm.amdgcn.fmed3.f32(float %0, float %1, float %2)
 }



More information about the Mlir-commits mailing list